Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions pkg-py/src/querychat/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from typing import TYPE_CHECKING, ClassVar, Protocol

import duckdb
import narwhals as nw
import narwhals.stable.v1 as nw
import pandas as pd
from sqlalchemy import inspect, text
from sqlalchemy.sql import sqltypes

if TYPE_CHECKING:
from narwhals.stable.v1.typing import IntoFrame
from sqlalchemy.engine import Connection, Engine


Expand Down Expand Up @@ -58,8 +59,9 @@ class DataFrameSource:
"""A DataSource implementation that wraps a pandas DataFrame using DuckDB."""

db_engine: ClassVar[str] = "DuckDB"
_df: nw.DataFrame | nw.LazyFrame

def __init__(self, df: pd.DataFrame, table_name: str):
def __init__(self, df: IntoFrame, table_name: str):
"""
Initialize with a pandas DataFrame.

Expand All @@ -69,9 +71,10 @@ def __init__(self, df: pd.DataFrame, table_name: str):

"""
self._conn = duckdb.connect(database=":memory:")
self._df = df
self._df = nw.from_native(df)
self._table_name = table_name
self._conn.register(table_name, df)
# TODO(@gadenbuie): If the data frame is already SQL-backed, maybe we shouldn't be making a new copy here.
self._conn.register(table_name, self._df.lazy().collect().to_pandas())

def get_schema(self, *, categorical_threshold: int) -> str:
"""
Expand All @@ -86,10 +89,15 @@ def get_schema(self, *, categorical_threshold: int) -> str:
String describing the schema

"""
ndf = nw.from_native(self._df)

schema = [f"Table: {self._table_name}", "Columns:"]

# Ensure we're working with a DataFrame, not a LazyFrame
ndf = (
self._df.head(10).collect()
if isinstance(self._df, nw.LazyFrame)
else self._df
)

for column in ndf.columns:
# Map pandas dtypes to SQL-like types
dtype = ndf[column].dtype
Expand Down Expand Up @@ -149,7 +157,8 @@ def get_data(self) -> pd.DataFrame:
The complete dataset as a pandas DataFrame

"""
return self._df.copy()
# TODO(@gadenbuie): This should just return `self._df` and not a pandas DataFrame
return self._df.lazy().collect().to_pandas()


class SQLAlchemySource:
Expand Down
9 changes: 3 additions & 6 deletions pkg-py/src/querychat/querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,15 +414,12 @@ def init(
data_source_obj: DataSource
if isinstance(data_source, sqlalchemy.Engine):
data_source_obj = SQLAlchemySource(data_source, table_name)
elif isinstance(data_source, (nw.DataFrame, nw.LazyFrame)):
else:
data_source_obj = DataFrameSource(
nw.to_native(data_source),
data_source,
table_name,
)
else:
raise TypeError(
"`data_source` must be a Narwhals DataFrame or LazyFrame, or a SQLAlchemy Engine",
)

# Process greeting
if greeting is None:
print(
Expand Down
3 changes: 2 additions & 1 deletion pkg-py/tests/test_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import pytest
from sqlalchemy import create_engine, text
from src.querychat.datasource import SQLAlchemySource

from querychat.datasource import SQLAlchemySource


@pytest.fixture
Expand Down
5 changes: 3 additions & 2 deletions pkg-py/tests/test_df_to_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pandas as pd
import pytest
from sqlalchemy import create_engine
from src.querychat.datasource import DataFrameSource, SQLAlchemySource
from src.querychat.querychat import df_to_html

from querychat.datasource import DataFrameSource, SQLAlchemySource
from querychat.querychat import df_to_html


@pytest.fixture
Expand Down
102 changes: 102 additions & 0 deletions pkg-py/tests/test_init_with_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os

import narwhals.stable.v1 as nw
import pandas as pd
import pytest

from querychat.querychat import init


@pytest.fixture(autouse=True)
def set_dummy_api_key():
    """
    Set a dummy OpenAI API key for the duration of each test.

    The previous value of ``OPENAI_API_KEY`` is restored once the test has
    finished; if the variable was not set beforehand, it is removed again.
    """
    env_var = "OPENAI_API_KEY"
    previous = os.environ.get(env_var)
    os.environ[env_var] = "sk-dummy-api-key-for-testing"
    yield
    if previous is None:
        # pop() instead of `del`: teardown must not raise KeyError when the
        # test itself has already removed the variable.
        os.environ.pop(env_var, None)
    else:
        os.environ[env_var] = previous


def test_init_with_pandas_dataframe():
    """Check that init() accepts a plain pandas DataFrame as its data source."""
    frame = pd.DataFrame(
        {
            "id": [1, 2, 3],
            "name": ["Alice", "Bob", "Charlie"],
            "age": [25, 30, 35],
        },
    )

    # A native pandas frame is passed straight in; init() must not raise
    config = init(
        data_source=frame,
        table_name="test_table",
        greeting="hello!",
    )

    # The returned configuration must exist and expose each expected attribute
    assert config is not None
    for attr in ("data_source", "system_prompt", "greeting", "client"):
        assert hasattr(config, attr)


def test_init_with_narwhals_dataframe():
    """Check that init() accepts a narwhals DataFrame as its data source."""
    records = {
        "id": [1, 2, 3],
        "name": ["Alice", "Bob", "Charlie"],
        "age": [25, 30, 35],
    }
    # Wrap a pandas frame in narwhals before handing it to init()
    wrapped = nw.from_native(pd.DataFrame(records))

    # init() must accept the narwhals wrapper without raising
    config = init(
        data_source=wrapped,
        table_name="test_table",
        greeting="hello!",
    )

    # The returned configuration must exist and carry its core attributes
    assert config is not None
    assert hasattr(config, "data_source")
    assert hasattr(config, "system_prompt")


def test_init_with_narwhals_lazyframe_direct_query():
    """Check that init() accepts a narwhals LazyFrame and that it can be queried."""
    records = {
        "id": [1, 2, 3],
        "name": ["Alice", "Bob", "Charlie"],
        "age": [25, 30, 35],
    }
    # Build the lazy narwhals view on top of a pandas frame
    lazy_frame = nw.from_native(pd.DataFrame(records)).lazy()

    config = init(
        data_source=lazy_frame,  # TODO(@gadenbuie): Fix this type error
        table_name="test_table",
        greeting="hello!",
    )

    # The returned configuration must exist and expose a data source
    assert config is not None
    assert hasattr(config, "data_source")

    # The registered table must be queryable through the data source
    rows = config.data_source.execute_query(
        "SELECT * FROM test_table WHERE id = 2",
    )
    assert len(rows) == 1
    assert rows.iloc[0]["name"] == "Bob"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ extend-ignore = [
"D107", # Missing docstring in __init__
"D205", # 1 blank line required between summary line and description
"UP045", # Use `X | None` for type annotations, not `Optional[X]`
"TD003", # TODO doesn't need to have an issue link
]
extend-select = [
# "C90", # C90; mccabe: https://docs.astral.sh/ruff/rules/complex-structure/
Expand Down