diff --git a/pkg-py/src/querychat/_querychat.py b/pkg-py/src/querychat/_querychat.py
index 3e4f59f5..52136587 100644
--- a/pkg-py/src/querychat/_querychat.py
+++ b/pkg-py/src/querychat/_querychat.py
@@ -40,6 +40,7 @@ def __init__(
         greeting: Optional[str | Path] = None,
         client: Optional[str | chatlas.Chat] = None,
         data_description: Optional[str | Path] = None,
+        categorical_threshold: int = 10,
         extra_instructions: Optional[str | Path] = None,
         prompt_template: Optional[str | Path] = None,
     ):
@@ -79,6 +80,9 @@ def __init__(
             Description of the data in plain text or Markdown. If a pathlib.Path
             object is passed, querychat will read the contents of the path into a
             string with `.read_text()`.
+        categorical_threshold
+            Threshold for determining if a column is categorical based on number of
+            unique values.
         extra_instructions
             Additional instructions for the chat model. If a pathlib.Path object is
             passed, querychat will read the contents of the path into a string with
@@ -104,7 +108,7 @@ def __init__(
         ```

         """
-        self.data_source = normalize_data_source(data_source, table_name)
+        self._data_source = normalize_data_source(data_source, table_name)

         # Validate table name (must begin with letter, contain only letters, numbers, underscores)
         if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name):
@@ -114,8 +118,6 @@ def __init__(

         self.id = id or table_name

-        self.client = normalize_client(client)
-
         if greeting is None:
             print(
                 "Warning: No greeting provided; the LLM will be invoked at conversation start to generate one. "
@@ -126,13 +128,20 @@ def __init__(

         self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting

-        self.system_prompt = get_system_prompt(
-            self.data_source,
+        prompt = get_system_prompt(
+            self._data_source,
             data_description=data_description,
             extra_instructions=extra_instructions,
+            categorical_threshold=categorical_threshold,
             prompt_template=prompt_template,
         )

+        # Fork and empty chat now so the per-session forks are fast
+        client = normalize_client(client)
+        self._client = copy.deepcopy(client)
+        self._client.set_turns([])
+        self._client.system_prompt = prompt
+
         # Populated when ._server() gets called (in an active session)
         self._server_values: ModServerResult | None = None

@@ -160,7 +169,7 @@ def app(
         """
         enable_bookmarking = bookmark_store != "disable"

-        table_name = self.data_source.table_name
+        table_name = self._data_source.table_name

         def app_ui(request):
             return ui.page_sidebar(
@@ -303,10 +312,9 @@ def _server(self, *, enable_bookmarking: bool = False) -> None:
         # Call the server module
         self._server_values = mod_server(
             self.id,
-            data_source=self.data_source,
-            system_prompt=self.system_prompt,
+            data_source=self._data_source,
             greeting=self.greeting,
-            client=self.client,
+            client=self._client,
             enable_bookmarking=enable_bookmarking,
         )

@@ -416,7 +424,7 @@ def title(self, value: Optional[str] = None) -> str | None | bool:
         else:
             return vals.title.set(value)

-    def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
+    def generate_greeting(self, *, echo: Literal["none", "output"] = "none"):
         """
         Generate a welcome greeting for the chat.

@@ -428,7 +436,7 @@ def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
         Parameters
         ----------
         echo
-            If `echo = "text"`, prints the greeting to standard output. If
+            If `echo = "output"`, prints the greeting to standard output. If
             `echo = "none"` (default), does not print anything.

         Returns
@@ -437,97 +445,39 @@ def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
             The greeting string (in Markdown format).

         """
-        client = copy.deepcopy(self.client)
-        client.system_prompt = self.system_prompt
+        client = copy.deepcopy(self._client)
         client.set_turns([])

         prompt = "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list."

         return str(client.chat(prompt, echo=echo))

-    def set_system_prompt(
-        self,
-        data_source: DataSource,
-        *,
-        data_description: Optional[str | Path] = None,
-        extra_instructions: Optional[str | Path] = None,
-        categorical_threshold: int = 10,
-        prompt_template: Optional[str | Path] = None,
-    ) -> None:
-        """
-        Customize the system prompt.
-
-        Control the logic behind how the system prompt is generated based on the
-        data source's schema and optional additional context and instructions.
-
-        Note
-        ----
-        This method is for parametrized system prompt generation only. To set a
-        fully custom system prompt string, set the `system_prompt` attribute
-        directly.
-
-        Parameters
-        ----------
-        data_source
-            A data source to generate schema information from
-        data_description
-            Optional description of the data, in plain text or Markdown format
-        extra_instructions
-            Optional additional instructions for the chat model, in plain text or
-            Markdown format
-        categorical_threshold
-            Threshold for determining if a column is categorical based on number of
-            unique values
-        prompt_template
-            Optional `Path` to or string of a custom prompt template. If not provided, the default
-            querychat template will be used.
-
-        """
-        self.system_prompt = get_system_prompt(
-            data_source,
-            data_description=data_description,
-            extra_instructions=extra_instructions,
-            categorical_threshold=categorical_threshold,
-            prompt_template=prompt_template,
-        )
-
-    def set_data_source(
-        self, data_source: IntoFrame | sqlalchemy.Engine | DataSource, table_name: str
-    ) -> None:
+    @property
+    def client(self):
         """
-        Set a new data source for the QueryChat object.
-
-        Parameters
-        ----------
-        data_source
-            The new data source to use.
-        table_name
-            If a data_source is a data frame, a name to use to refer to the table
+        Get the (session-specific) chat client.

         Returns
         -------
         :
-            None
+            The current chat client.

         """
-        self.data_source = normalize_data_source(data_source, table_name)
+        vals = self._server_values
+        if vals is None:
+            raise RuntimeError("Must call .server() before accessing .client")
+        return vals.client

-    def set_client(self, client: str | chatlas.Chat) -> None:
+    @property
+    def data_source(self):
         """
-        Set a new chat client for the QueryChat object.
-
-        Parameters
-        ----------
-        client
-            A `chatlas.Chat` object or a string to be passed to
-            `chatlas.ChatAuto()` describing the model to use (e.g.
-            `"openai/gpt-4.1"`).
+        Get the current data source.

         Returns
         -------
         :
-            None
+            The current data source.

         """
-        self.client = normalize_client(client)
+        return self._data_source


 class QueryChat(QueryChatBase):
diff --git a/pkg-py/src/querychat/_querychat_module.py b/pkg-py/src/querychat/_querychat_module.py
index 6153a5b8..4fd83570 100644
--- a/pkg-py/src/querychat/_querychat_module.py
+++ b/pkg-py/src/querychat/_querychat_module.py
@@ -58,7 +58,6 @@ def mod_server(
     session: Session,
     *,
     data_source: DataSource,
-    system_prompt: str,
     greeting: str | None,
     client: chatlas.Chat,
     enable_bookmarking: bool,
@@ -70,8 +69,6 @@ def mod_server(

     # Set up the chat object for this session
     chat = copy.deepcopy(client)
-    chat.set_turns([])
-    chat.system_prompt = system_prompt

     # Create the tool functions
     update_dashboard_tool = tool_update_dashboard(data_source, sql, title)
diff --git a/pkg-py/tests/test_init_with_pandas.py b/pkg-py/tests/test_init_with_pandas.py
index 3f94b639..7c182639 100644
--- a/pkg-py/tests/test_init_with_pandas.py
+++ b/pkg-py/tests/test_init_with_pandas.py
@@ -39,10 +39,6 @@ def test_init_with_pandas_dataframe():

     # Verify the result is properly configured
     assert qc is not None
-    assert hasattr(qc, "data_source")
-    assert hasattr(qc, "system_prompt")
-    assert hasattr(qc, "greeting")
-    assert hasattr(qc, "client")


 def test_init_with_narwhals_dataframe():
@@ -66,8 +62,6 @@ def test_init_with_narwhals_dataframe():

     # Verify the result is correctly configured
     assert qc is not None
-    assert hasattr(qc, "data_source")
-    assert hasattr(qc, "system_prompt")


 def test_init_with_narwhals_lazyframe_direct_query():
diff --git a/pkg-py/tests/test_querychat.py b/pkg-py/tests/test_querychat.py
index b8267b46..5ffce90b 100644
--- a/pkg-py/tests/test_querychat.py
+++ b/pkg-py/tests/test_querychat.py
@@ -39,10 +39,6 @@ def test_querychat_init(sample_df):

     # Verify basic attributes are set
     assert qc is not None
-    assert hasattr(qc, "data_source")
-    assert hasattr(qc, "system_prompt")
-    assert hasattr(qc, "greeting")
-    assert hasattr(qc, "client")
     assert qc.id == "test_table"

     # Even without server initialization, we should be able to query the data source
@@ -66,31 +62,6 @@ def test_querychat_custom_id(sample_df):
     assert qc.id == "custom_id"


-def test_querychat_set_methods(sample_df):
-    """Test that setter methods work."""
-    qc = QueryChat(
-        data_source=sample_df,
-        table_name="test_table",
-        greeting="Hello!",
-    )
-
-    # Test set_system_prompt
-    qc.set_system_prompt(
-        qc.data_source,
-        data_description="A test dataset",
-    )
-    assert "test dataset" in qc.system_prompt.lower()
-
-    # Test set_data_source
-    new_df = pd.DataFrame({"x": [1, 2, 3]})
-    qc.set_data_source(new_df, "new_table")
-    assert qc.data_source is not None
-
-    # Test set_client
-    qc.set_client("openai/gpt-4o-mini")
-    assert qc.client is not None
-
-
 def test_querychat_core_reactive_access_before_server_raises(sample_df):
     """Test that accessing reactive properties before .server() raises error."""
     qc = QueryChat(