From 2f2ad57b42764fa849fdd9ab64eafc5c55574cbd Mon Sep 17 00:00:00 2001 From: Garrick Aden-Buie Date: Thu, 4 Sep 2025 17:41:34 -0400 Subject: [PATCH] fix(pkg-py): Fix data source inputs that are narwhals compatible --- pkg-py/src/querychat/datasource.py | 23 ++++-- pkg-py/src/querychat/querychat.py | 9 +-- pkg-py/tests/test_datasource.py | 3 +- pkg-py/tests/test_df_to_html.py | 5 +- pkg-py/tests/test_init_with_pandas.py | 102 ++++++++++++++++++++++++++ pyproject.toml | 1 + 6 files changed, 127 insertions(+), 16 deletions(-) create mode 100644 pkg-py/tests/test_init_with_pandas.py diff --git a/pkg-py/src/querychat/datasource.py b/pkg-py/src/querychat/datasource.py index 3261e0a9..f04fb955 100644 --- a/pkg-py/src/querychat/datasource.py +++ b/pkg-py/src/querychat/datasource.py @@ -3,12 +3,13 @@ from typing import TYPE_CHECKING, ClassVar, Protocol import duckdb -import narwhals as nw +import narwhals.stable.v1 as nw import pandas as pd from sqlalchemy import inspect, text from sqlalchemy.sql import sqltypes if TYPE_CHECKING: + from narwhals.stable.v1.typing import IntoFrame from sqlalchemy.engine import Connection, Engine @@ -58,8 +59,9 @@ class DataFrameSource: """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" db_engine: ClassVar[str] = "DuckDB" + _df: nw.DataFrame | nw.LazyFrame - def __init__(self, df: pd.DataFrame, table_name: str): + def __init__(self, df: IntoFrame, table_name: str): """ Initialize with a pandas DataFrame. @@ -69,9 +71,10 @@ def __init__(self, df: pd.DataFrame, table_name: str): """ self._conn = duckdb.connect(database=":memory:") - self._df = df + self._df = nw.from_native(df) self._table_name = table_name - self._conn.register(table_name, df) + # TODO(@gadenbuie): If the data frame is already SQL-backed, maybe we shouldn't be making a new copy here. 
+ self._conn.register(table_name, self._df.lazy().collect().to_pandas()) def get_schema(self, *, categorical_threshold: int) -> str: """ @@ -86,10 +89,15 @@ def get_schema(self, *, categorical_threshold: int) -> str: String describing the schema """ - ndf = nw.from_native(self._df) - schema = [f"Table: {self._table_name}", "Columns:"] + # Ensure we're working with a DataFrame, not a LazyFrame + ndf = ( + self._df.head(10).collect() + if isinstance(self._df, nw.LazyFrame) + else self._df + ) + for column in ndf.columns: # Map pandas dtypes to SQL-like types dtype = ndf[column].dtype @@ -149,7 +157,8 @@ def get_data(self) -> pd.DataFrame: The complete dataset as a pandas DataFrame """ - return self._df.copy() + # TODO(@gadenbuie): This should just return `self._df` and not a pandas DataFrame + return self._df.lazy().collect().to_pandas() class SQLAlchemySource: diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index f449ede8..5fdcb868 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -414,15 +414,12 @@ def init( data_source_obj: DataSource if isinstance(data_source, sqlalchemy.Engine): data_source_obj = SQLAlchemySource(data_source, table_name) - elif isinstance(data_source, (nw.DataFrame, nw.LazyFrame)): + else: data_source_obj = DataFrameSource( - nw.to_native(data_source), + data_source, table_name, ) - else: - raise TypeError( - "`data_source` must be a Narwhals DataFrame or LazyFrame, or a SQLAlchemy Engine", - ) + # Process greeting if greeting is None: print( diff --git a/pkg-py/tests/test_datasource.py b/pkg-py/tests/test_datasource.py index 734cc4c7..0dd69594 100644 --- a/pkg-py/tests/test_datasource.py +++ b/pkg-py/tests/test_datasource.py @@ -4,7 +4,8 @@ import pytest from sqlalchemy import create_engine, text -from src.querychat.datasource import SQLAlchemySource + +from querychat.datasource import SQLAlchemySource @pytest.fixture diff --git a/pkg-py/tests/test_df_to_html.py 
b/pkg-py/tests/test_df_to_html.py index 63d306c1..d98f60bb 100644 --- a/pkg-py/tests/test_df_to_html.py +++ b/pkg-py/tests/test_df_to_html.py @@ -5,8 +5,9 @@ import pandas as pd import pytest from sqlalchemy import create_engine -from src.querychat.datasource import DataFrameSource, SQLAlchemySource -from src.querychat.querychat import df_to_html + +from querychat.datasource import DataFrameSource, SQLAlchemySource +from querychat.querychat import df_to_html @pytest.fixture diff --git a/pkg-py/tests/test_init_with_pandas.py b/pkg-py/tests/test_init_with_pandas.py new file mode 100644 index 00000000..8b1a255b --- /dev/null +++ b/pkg-py/tests/test_init_with_pandas.py @@ -0,0 +1,102 @@ +import os + +import narwhals.stable.v1 as nw +import pandas as pd +import pytest + +from querychat.querychat import init + + +@pytest.fixture(autouse=True) +def set_dummy_api_key(): + """Set a dummy OpenAI API key for testing.""" + old_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "sk-dummy-api-key-for-testing" + yield + if old_api_key is not None: + os.environ["OPENAI_API_KEY"] = old_api_key + else: + del os.environ["OPENAI_API_KEY"] + + +def test_init_with_pandas_dataframe(): + """Test that init() can accept a pandas DataFrame.""" + # Create a simple pandas DataFrame + df = pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + }, + ) + + # Call init with the pandas DataFrame - it should not raise errors + # The function should accept a pandas DataFrame even with the narwhals import change + result = init( + data_source=df, + table_name="test_table", + greeting="hello!", + ) + + # Verify the result is an instance of QueryChatConfig + assert result is not None + assert hasattr(result, "data_source") + assert hasattr(result, "system_prompt") + assert hasattr(result, "greeting") + assert hasattr(result, "client") + + +def test_init_with_narwhals_dataframe(): + """Test that init() can accept a narwhals 
DataFrame.""" + # Create a pandas DataFrame and convert to narwhals + pdf = pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + }, + ) + nw_df = nw.from_native(pdf) + + # Call init with the narwhals DataFrame - it should not raise errors + result = init( + data_source=nw_df, + table_name="test_table", + greeting="hello!", + ) + + # Verify the result is correctly configured + assert result is not None + assert hasattr(result, "data_source") + assert hasattr(result, "system_prompt") + + +def test_init_with_narwhals_lazyframe_direct_query(): + """Test that init() can accept a narwhals LazyFrame and execute queries.""" + # Create a pandas DataFrame and convert to narwhals LazyFrame + pdf = pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + }, + ) + nw_lazy = nw.from_native(pdf).lazy() + + # Call init with the narwhals LazyFrame + result = init( + data_source=nw_lazy, # TODO(@gadenbuie): Fix this type error + table_name="test_table", + greeting="hello!", + ) + + # Verify the result is correctly configured + assert result is not None + assert hasattr(result, "data_source") + + # Test that we can run a query on the data source + query_result = result.data_source.execute_query( + "SELECT * FROM test_table WHERE id = 2", + ) + assert len(query_result) == 1 + assert query_result.iloc[0]["name"] == "Bob" diff --git a/pyproject.toml b/pyproject.toml index 773767b0..4b597d0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,7 @@ extend-ignore = [ "D107", # Missing docstring in __init__ "D205", # 1 blank line required between summary line and description "UP045", # Use `X | NULL` for type annotations, not `Optional[X]` + "TD003", # TODO doesn't need to have an issue link ] extend-select = [ # "C90", # C90; mccabe: https://docs.astral.sh/ruff/rules/complex-structure/