feat(python): Expose infer_schema_length parameter on `read_database` (#15076)

alexander-beedie committed Mar 15, 2024
1 parent bc301f2 commit 0abbe5c
Showing 4 changed files with 61 additions and 10 deletions.
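For orientation, a minimal sketch of the newly exposed parameter in use; the database file and table name here are hypothetical, not part of this diff:

```python
import sqlite3

import polars as pl

# Hypothetical DB-API2 connection; any supported driver behaves the same way.
conn = sqlite3.connect("example.db")

# Scan up to 500 rows when inferring column dtypes for row-wise drivers;
# pass None to scan the full result set (slower, but catches types that
# only appear in later rows).
df = pl.read_database(
    query="SELECT * FROM example_table",
    connection=conn,
    infer_schema_length=500,
)
```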
7 changes: 3 additions & 4 deletions py-polars/polars/dataframe/frame.py
@@ -234,10 +234,9 @@ class DataFrame:
         the orientation is inferred by matching the columns and data dimensions. If
         this does not yield conclusive results, column orientation is used.
     infer_schema_length : int or None
-        The maximum number of rows to scan for schema inference.
-        If set to `None`, the full data may be scanned *(this is slow)*.
-        This parameter only applies if the input data is a sequence or generator of
-        rows; other input is read as-is.
+        The maximum number of rows to scan for schema inference. If set to `None`, the
+        full data may be scanned *(this can be slow)*. This parameter only applies if
+        the input data is a sequence or generator of rows; other input is read as-is.
     nan_to_null : bool, default False
         If the data comes from one or more numpy arrays, can optionally convert input
         data np.nan values to null instead. This is a no-op for all other input data.
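The same wording now appears on the `DataFrame` constructor; a small sketch of the row-wise inference it describes (illustrative values, not from this commit):

```python
import polars as pl

# Row-oriented input: the dtypes of "x" and "y" are only discoverable from
# the second row, so the scan window must reach it (None scans everything).
rows = [(None, None), ("foo", 1.5)]
df = pl.DataFrame(rows, schema=["x", "y"], orient="row", infer_schema_length=None)
print(df.schema)  # {'x': String, 'y': Float64}
```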
18 changes: 17 additions & 1 deletion py-polars/polars/io/database.py
@@ -8,6 +8,7 @@
 
 from polars._utils.deprecation import issue_deprecation_warning
 from polars.convert import from_arrow
+from polars.datatypes import N_INFER_DEFAULT
 from polars.exceptions import InvalidOperationError, UnsuitableSQLError
 
 if TYPE_CHECKING:
@@ -249,6 +250,7 @@ def _from_arrow(
         batch_size: int | None,
         iter_batches: bool,
         schema_overrides: SchemaDict | None,
+        infer_schema_length: int | None,
     ) -> DataFrame | Iterable[DataFrame] | None:
         """Return resultset data in Arrow format for frame init."""
         from polars import from_arrow
@@ -285,6 +287,7 @@ def _from_rows(
         batch_size: int | None,
         iter_batches: bool,
         schema_overrides: SchemaDict | None,
+        infer_schema_length: int | None,
     ) -> DataFrame | Iterable[DataFrame] | None:
         """Return resultset data row-wise for frame init."""
         from polars import DataFrame
@@ -304,12 +307,12 @@ def _from_rows(
         # TODO: refine types based on the cursor description's type_code,
         # if/where available? (for now, we just read the column names)
         result_columns = list(cursor_desc)
-
         frames = (
             DataFrame(
                 data=rows,
                 schema=result_columns,
                 schema_overrides=schema_overrides,
+                infer_schema_length=infer_schema_length,
                 orient="row",
             )
             for rows in (
@@ -420,6 +423,7 @@ def to_polars(
         iter_batches: bool = False,
         batch_size: int | None = None,
         schema_overrides: SchemaDict | None = None,
+        infer_schema_length: int | None = N_INFER_DEFAULT,
     ) -> DataFrame | Iterable[DataFrame]:
         """
         Convert the result set to a DataFrame.
@@ -444,6 +448,7 @@ def to_polars(
            batch_size=batch_size,
            iter_batches=iter_batches,
            schema_overrides=schema_overrides,
+           infer_schema_length=infer_schema_length,
        )
        if frame is not None:
            return frame
@@ -462,6 +467,8 @@ def read_database(
     iter_batches: Literal[False] = False,
     batch_size: int | None = ...,
     schema_overrides: SchemaDict | None = ...,
+    infer_schema_length: int | None = ...,
+    execute_options: dict[str, Any] | None = ...,
     **kwargs: Any,
 ) -> DataFrame: ...
 
@@ -474,6 +481,8 @@ def read_database(
     iter_batches: Literal[True],
     batch_size: int | None = ...,
     schema_overrides: SchemaDict | None = ...,
+    infer_schema_length: int | None = ...,
+    execute_options: dict[str, Any] | None = ...,
     **kwargs: Any,
 ) -> Iterable[DataFrame]: ...
 
@@ -485,6 +494,7 @@ def read_database(  # noqa: D417
     iter_batches: bool = False,
     batch_size: int | None = None,
     schema_overrides: SchemaDict | None = None,
+    infer_schema_length: int | None = N_INFER_DEFAULT,
     execute_options: dict[str, Any] | None = None,
     **kwargs: Any,
 ) -> DataFrame | Iterable[DataFrame]:
@@ -523,6 +533,11 @@ def read_database(  # noqa: D417
         on driver/backend). This can be useful if the given types can be more precisely
         defined (for example, if you know that a given column can be declared as `u32`
         instead of `i64`).
+    infer_schema_length
+        The maximum number of rows to scan for schema inference. If set to `None`, the
+        full data may be scanned *(this can be slow)*. This parameter only applies if
+        the data is read as a sequence of rows and the `schema_overrides` parameter
+        is not set for the given column; Arrow-aware drivers also ignore this value.
     execute_options
         These options will be passed through into the underlying query execution method
         as kwargs. In the case of connections made using an ODBC string (which use
@@ -657,6 +672,7 @@ def read_database(  # noqa: D417
         batch_size=batch_size,
         iter_batches=iter_batches,
         schema_overrides=schema_overrides,
+        infer_schema_length=infer_schema_length,
     )
 
 
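To illustrate the precedence rules stated in the new docstring text, a hedged sketch (connection and column names are hypothetical):

```python
import sqlite3

import polars as pl

conn = sqlite3.connect("example.db")  # hypothetical database

# "value" is pinned by schema_overrides, so inference never touches it;
# infer_schema_length=None means the remaining columns are inferred from
# a full scan. Arrow-native drivers ignore this parameter entirely.
df = pl.read_database(
    query="SELECT name, value FROM example_table",
    connection=conn,
    schema_overrides={"value": pl.Float32},
    infer_schema_length=None,
)
```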
7 changes: 3 additions & 4 deletions py-polars/polars/lazyframe/frame.py
@@ -176,10 +176,9 @@ class LazyFrame:
         the orientation is inferred by matching the columns and data dimensions. If
         this does not yield conclusive results, column orientation is used.
     infer_schema_length : int or None
-        The maximum number of rows to scan for schema inference.
-        If set to `None`, the full data may be scanned *(this is slow)*.
-        This parameter only applies if the input data is a sequence or generator of
-        rows; other input is read as-is.
+        The maximum number of rows to scan for schema inference. If set to `None`, the
+        full data may be scanned *(this can be slow)*. This parameter only applies if
+        the input data is a sequence or generator of rows; other input is read as-is.
     nan_to_null : bool, default False
         If the data comes from one or more numpy arrays, can optionally convert input
         data np.nan values to null instead. This is a no-op for all other input data.
39 changes: 38 additions & 1 deletion py-polars/tests/unit/io/test_database_read.py
@@ -16,7 +16,7 @@
 from sqlalchemy.sql.expression import cast as alchemy_cast
 
 import polars as pl
-from polars.exceptions import UnsuitableSQLError
+from polars.exceptions import ComputeError, UnsuitableSQLError
 from polars.io.database import _ARROW_DRIVER_REGISTRY_
 from polars.testing import assert_frame_equal
 
@@ -78,6 +78,21 @@ def convert_date(val: bytes) -> date:
     return test_db
 
 
+@pytest.fixture()
+def tmp_sqlite_inference_db(tmp_path: Path) -> Path:
+    test_db = tmp_path / "test_inference.db"
+    test_db.unlink(missing_ok=True)
+    conn = sqlite3.connect(test_db)
+    conn.executescript(
+        """
+        CREATE TABLE IF NOT EXISTS test_data (name TEXT, value FLOAT);
+        REPLACE INTO test_data(name,value) VALUES (NULL,NULL), ('foo',0);
+        """
+    )
+    conn.close()
+    return test_db
+
+
 class DatabaseReadTestParams(NamedTuple):
     """Clarify read test params."""
 
@@ -704,6 +719,28 @@ def test_read_database_cx_credentials(uri: str) -> None:
         pl.read_database_uri("SELECT * FROM data", uri=uri)
 
 
+def test_database_infer_schema_length(tmp_sqlite_inference_db: Path) -> None:
+    # note: first row of this test database contains only NULL values
+    conn = sqlite3.connect(tmp_sqlite_inference_db)
+    for infer_len in (2, 100, None):
+        df = pl.read_database(
+            connection=conn,
+            query="SELECT * FROM test_data",
+            infer_schema_length=infer_len,
+        )
+        assert df.schema == {"name": pl.String, "value": pl.Float64}
+
+    with pytest.raises(
+        ComputeError,
+        match='could not append value: "foo" of type: str.*`infer_schema_length`',
+    ):
+        pl.read_database(
+            connection=conn,
+            query="SELECT * FROM test_data",
+            infer_schema_length=1,
+        )
+
+
 @pytest.mark.write_disk()
 def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None:
     import kuzu
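Outside pytest, the failing case exercised above looks roughly like this (a sketch against the same fixture table; the exact error text may vary between versions):

```python
import sqlite3

import polars as pl

conn = sqlite3.connect("test_inference.db")  # assumes the fixture table above

# Only the all-NULL first row is scanned, so "name" gets no usable dtype and
# appending the str value "foo" from the second row raises a ComputeError.
try:
    pl.read_database(
        query="SELECT * FROM test_data",
        connection=conn,
        infer_schema_length=1,
    )
except pl.exceptions.ComputeError as exc:
    print(exc)  # message suggests increasing `infer_schema_length`
```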
