feat(python): support batched frame iteration over read_database queries (#11664)
alexander-beedie committed Oct 11, 2023
1 parent 987afb8 commit 56a7817
Showing 3 changed files with 310 additions and 191 deletions.
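In short, this commit lets `read_database` yield DataFrames batch by batch instead of always collecting the full result set. A minimal usage sketch of the new behaviour (the database file and table name are hypothetical; a plain DB-API2 connection such as `sqlite3` exercises the row-wise fallback path below):

```python
import sqlite3

import polars as pl

conn = sqlite3.connect("example.db")  # hypothetical local database

# Eager behaviour (unchanged): one DataFrame containing the full result set.
df = pl.read_database("SELECT * FROM test_data", conn)

# New in this commit: iterate over the result set as a series of DataFrames,
# keeping peak memory bounded by the batch size.
for batch in pl.read_database(
    "SELECT * FROM test_data",
    conn,
    iter_batches=True,
    batch_size=1_000,
):
    print(batch.shape)
```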
207 changes: 135 additions & 72 deletions py-polars/polars/io/database.py
@@ -3,14 +3,11 @@
import re
import sys
from importlib import import_module
from typing import TYPE_CHECKING, Any, Iterable, Sequence, TypedDict
from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence, TypedDict, overload

from polars.convert import from_arrow
from polars.exceptions import UnsuitableSQLError
from polars.utils.deprecation import (
deprecate_renamed_parameter,
issue_deprecation_warning,
)
from polars.exceptions import InvalidOperationError, UnsuitableSQLError
from polars.utils.deprecation import issue_deprecation_warning

if TYPE_CHECKING:
from types import TracebackType
@@ -47,8 +44,8 @@ class _DriverProperties_(TypedDict):
"exact_batch_size": None,
},
"arrow_odbc_proxy": {
"fetch_all": "fetchall",
"fetch_batches": "fetchmany",
"fetch_all": "fetch_record_batches",
"fetch_batches": "fetch_record_batches",
"exact_batch_size": True,
},
"databricks": {
@@ -104,9 +101,9 @@ def execute(self, query: str, **execute_options: Any) -> None:
self.execute_options = execute_options
self.query = query

def fetchmany(
def fetch_record_batches(
self, batch_size: int = 10_000
) -> Iterable[pa.RecordBatch | pa.Table]:
) -> Iterable[pa.RecordBatch]:
"""Fetch results in batches."""
from arrow_odbc import read_arrow_batches_from_odbc

@@ -118,7 +115,7 @@ def fetchmany(
)

# internally arrow-odbc always reads batches
fetchall = fetchmany
fetchall = fetchmany = fetch_record_batches
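For context, the ODBC cursor proxy above simply forwards to `arrow_odbc`. A rough sketch of the underlying call, assuming `read_arrow_batches_from_odbc` accepts `query`, `connection_string`, and `batch_size` keyword arguments (the exact call site is elided from this hunk, and the connection string is hypothetical):

```python
from arrow_odbc import read_arrow_batches_from_odbc  # requires the `arrow-odbc` package

# The reader yields pyarrow.RecordBatch objects lazily, one per fetched batch.
batches = read_arrow_batches_from_odbc(
    query="SELECT * FROM test_data",
    connection_string="Driver={PostgreSQL};Server=localhost;Port=5432;Database=test;Uid=usr;Pwd=",
    batch_size=10_000,
)
for batch in batches:
    print(batch.num_rows)
```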


class ConnectionExecutor:
@@ -173,13 +170,6 @@ def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:
f"Unrecognised connection {conn!r}; unable to find 'execute' method"
)

@staticmethod
def _fetch_arrow(
result: Cursor, fetch_method: str, batch_size: int | None
) -> Iterable[pa.RecordBatch | pa.Table]:
"""Iterate over the result set, fetching arrow data in batches."""
yield from getattr(result, fetch_method)(batch_size)

@staticmethod
def _fetchall_rows(result: Cursor) -> Iterable[Sequence[Any]]:
"""Fetch row data in a single call, returning the complete result set."""
@@ -198,15 +188,18 @@ def _fetchmany_rows(
rows = result.fetchmany(batch_size)
if not rows:
break
elif not isinstance(rows[0], (list, tuple)):
for row in rows:
yield tuple(row)
elif isinstance(rows[0], (list, tuple)):
yield rows
else:
yield from rows
yield [tuple(row) for row in rows]
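A self-contained toy illustration of the batching logic above; the `FakeCursor` is purely hypothetical, standing in for a DB-API cursor whose `fetchmany` may return row objects that are not already tuples:

```python
class FakeCursor:
    """Minimal stand-in for a DB-API cursor (hypothetical, for illustration only)."""

    def __init__(self, rows):
        self._rows = rows

    def fetchmany(self, n):
        out, self._rows = self._rows[:n], self._rows[n:]
        return out


def fetchmany_rows(result, batch_size):
    # mirrors the rewritten logic: yield one list of tuples per fetched batch
    while True:
        rows = result.fetchmany(batch_size)
        if not rows:
            break
        elif isinstance(rows[0], (list, tuple)):
            yield rows
        else:
            yield [tuple(row) for row in rows]


for batch in fetchmany_rows(FakeCursor([(1, "a"), (2, "b"), (3, "c")]), batch_size=2):
    print(batch)  # [(1, 'a'), (2, 'b')] then [(3, 'c')]
```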

def _from_arrow(
self, batch_size: int | None, schema_overrides: SchemaDict | None
) -> DataFrame | None:
self,
*,
batch_size: int | None,
iter_batches: bool,
schema_overrides: SchemaDict | None,
) -> DataFrame | Iterable[DataFrame] | None:
"""Return resultset data in Arrow format for frame init."""
from polars import from_arrow

@@ -215,14 +208,17 @@ def _from_arrow(
if re.match(f"^{driver}$", self.driver_name):
size = batch_size if driver_properties["exact_batch_size"] else None
fetch_batches = driver_properties["fetch_batches"]
return from_arrow( # type: ignore[return-value]
data=(
self._fetch_arrow(self.result, fetch_batches, size)
if batch_size and fetch_batches is not None
else getattr(self.result, driver_properties["fetch_all"])()
),
schema_overrides=schema_overrides,
frames = (
from_arrow(batch, schema_overrides=schema_overrides)
for batch in (
getattr(self.result, fetch_batches)(size)
if (iter_batches and fetch_batches is not None)
else [
getattr(self.result, driver_properties["fetch_all"])()
]
)
)
return frames if iter_batches else next(frames) # type: ignore[arg-type,return-value]
except Exception as err:
# eg: valid turbodbc/snowflake connection, but no arrow support
# available in the underlying driver or this connection
@@ -236,8 +232,12 @@ def _from_arrow(
return None

def _from_rows(
self, batch_size: int | None, schema_overrides: SchemaDict | None
) -> DataFrame | None:
self,
*,
batch_size: int | None,
iter_batches: bool,
schema_overrides: SchemaDict | None,
) -> DataFrame | Iterable[DataFrame] | None:
"""Return resultset data row-wise for frame init."""
from polars import DataFrame

@@ -248,16 +248,20 @@ def _from_rows(
else self.result.description
)
column_names = [desc[0] for desc in description]
return DataFrame(
data=(
self._fetchall_rows(self.result)
if not batch_size
else self._fetchmany_rows(self.result, batch_size)
),
schema=column_names,
schema_overrides=schema_overrides,
orient="row",
frames = (
DataFrame(
data=rows,
schema=column_names,
schema_overrides=schema_overrides,
orient="row",
)
for rows in (
self._fetchmany_rows(self.result, batch_size)
if iter_batches
else [self._fetchall_rows(self.result)] # type: ignore[list-item]
)
)
return frames if iter_batches else next(frames) # type: ignore[arg-type]
return None
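Both `_from_arrow` and `_from_rows` now share the same lazy idiom: build a generator of DataFrames and either hand it back as-is (`iter_batches=True`) or pull the single frame it yields. A small sketch of that idiom in isolation (names here are illustrative, not part of the module):

```python
from __future__ import annotations

from typing import Iterable

import polars as pl


def frames_or_frame(
    batches: list[list[int]], *, iter_batches: bool
) -> pl.DataFrame | Iterable[pl.DataFrame]:
    # One DataFrame per batch of row data (a toy stand-in for the cursor results).
    frames = (pl.DataFrame({"x": batch}) for batch in batches)
    # Either return the iterator itself, or eagerly pull the single frame it yields.
    return frames if iter_batches else next(frames)


print(frames_or_frame([[1, 2, 3]], iter_batches=False).shape)  # (3, 1)
for df in frames_or_frame([[1, 2], [3, 4]], iter_batches=True):
    print(df.shape)  # (2, 1), twice
```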

def execute(
@@ -286,9 +290,13 @@ def execute(
self.result = result
return self

def to_frame(
self, batch_size: int | None = None, schema_overrides: SchemaDict | None = None
) -> DataFrame:
def to_polars(
self,
*,
iter_batches: bool = False,
batch_size: int | None = None,
schema_overrides: SchemaDict | None = None,
) -> DataFrame | Iterable[DataFrame]:
"""
Convert the result set to a DataFrame.
@@ -297,12 +305,20 @@ def to_frame(
"""
if self.result is None:
raise RuntimeError("Cannot return a frame before executing a query")
elif iter_batches and not batch_size:
raise ValueError(
"Cannot set `iter_batches` without also setting a non-zero `batch_size`"
)

for frame_init in (
self._from_arrow, # init from arrow-native data (most efficient option)
self._from_rows, # row-wise fallback covering sqlalchemy, dbapi2, pyodbc
self._from_arrow, # init from arrow-native data (where support exists)
self._from_rows, # row-wise fallback (sqlalchemy, dbapi2, pyodbc, etc)
):
frame = frame_init(batch_size=batch_size, schema_overrides=schema_overrides)
frame = frame_init(
batch_size=batch_size,
iter_batches=iter_batches,
schema_overrides=schema_overrides,
)
if frame is not None:
return frame

@@ -311,16 +327,42 @@
)


@deprecate_renamed_parameter("connection_uri", "connection", version="0.18.9")
def read_database( # noqa D417
@overload
def read_database(
query: str | Selectable,
connection: ConnectionOrCursor | str,
*,
iter_batches: Literal[False] = False,
batch_size: int | None = ...,
schema_overrides: SchemaDict | None = ...,
**kwargs: Any,
) -> DataFrame:
...


@overload
def read_database(
query: str | Selectable,
connection: ConnectionOrCursor | str,
*,
iter_batches: Literal[True],
batch_size: int | None = ...,
schema_overrides: SchemaDict | None = ...,
**kwargs: Any,
) -> Iterable[DataFrame]:
...
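The two `@overload` stubs let static type-checkers resolve the return type from the literal value of `iter_batches`; a condensed sketch of the same pattern (the `fetch` function is illustrative, not part of the module):

```python
from __future__ import annotations

from typing import Iterable, Literal, overload

import polars as pl


@overload
def fetch(*, iter_batches: Literal[False] = ...) -> pl.DataFrame:
    ...


@overload
def fetch(*, iter_batches: Literal[True]) -> Iterable[pl.DataFrame]:
    ...


def fetch(*, iter_batches: bool = False) -> pl.DataFrame | Iterable[pl.DataFrame]:
    frames = (pl.DataFrame({"n": [i]}) for i in range(3))
    return frames if iter_batches else next(frames)


# reveal_type(fetch())                   # type-checker: DataFrame
# reveal_type(fetch(iter_batches=True))  # type-checker: Iterable[DataFrame]
```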


def read_database( # noqa: D417
query: str | Selectable,
connection: ConnectionOrCursor | str,
*,
iter_batches: bool = False,
batch_size: int | None = None,
schema_overrides: SchemaDict | None = None,
execute_options: dict[str, Any] | None = None,
**kwargs: Any,
) -> DataFrame:
) -> DataFrame | Iterable[DataFrame]:
"""
Read the results of a SQL query into a DataFrame, given a connection object.
@@ -334,21 +376,27 @@ def read_database( # noqa D417
executed against. Can also pass a valid ODBC connection string, starting with
"Driver=", in which case the ``arrow-odbc`` package will be used to establish
the connection and return Arrow-native data to Polars.
batch_size
Enable batched data fetching (internally) instead of collecting all rows at
once; this can be helpful for minimising the peak memory used for very large
resultsets. Note that this parameter is *not* equivalent to a "limit"; you
will always load all rows. If supported by the backend, this value is passed
iter_batches
Return an iterator of DataFrames, where each DataFrame represents a batch of
data returned by the query; this can be useful for processing large resultsets
in a memory-efficient manner. If supported by the backend, this value is passed
to the underlying query execution method (note that very low values will
typically result in poor performance as it will result in many round-trips to
the database as the data is returned). If the backend does not support changing
the batch size, this parameter is ignored without error.
the batch size, then a single DataFrame is yielded from the iterator.
batch_size
Indicate the size of each batch when ``iter_batches`` is True (note that you can
still set this when ``iter_batches`` is False, in which case the resulting
DataFrame is constructed internally using batched retrieval before being returned
to you). Note that some backends may support batched operation but not allow for
an explicit size; in this case you will still receive batches, but their exact
size will be determined by the backend (so may not equal the value set here).
schema_overrides
A dictionary mapping column names to dtypes, used to override the schema
inferred from the query cursor or given by the incoming Arrow data (depending
on driver/backend). This can be useful if the given types can be more precisely
defined (for example, if you know that a given column can be declared as `u32`
instead of `i64`).
instead of ``i64``).
execute_options
These options will be passed through into the underlying query execution method
as kwargs. In the case of connections made using an ODBC string (which use
@@ -403,13 +451,17 @@ def read_database( # noqa D417
... ) # doctest: +SKIP
Instantiate a DataFrame using an ODBC connection string (requires ``arrow-odbc``)
and set upper limits on the buffer size of variadic text/binary columns:
setting upper limits on the buffer size of variadic text/binary columns, returning
the result as an iterator over DataFrames containing batches of 1000 rows:
>>> df = pl.read_database(
>>> for df in pl.read_database(
... query="SELECT * FROM test_data",
... connection="Driver={PostgreSQL};Server=localhost;Port=5432;Database=test;Uid=usr;Pwd=",
... execute_options={"max_text_size": 512, "max_binary_size": 1024},
... ) # doctest: +SKIP
... iter_batches=True,
... batch_size=1000,
... ):
... do_something(df) # doctest: +SKIP
""" # noqa: W505
if isinstance(connection, str):
@@ -427,9 +479,13 @@ def read_database( # noqa D417
else:
# otherwise looks like a call to read_database_uri
issue_deprecation_warning(
message="Use of a string URI with 'read_database' is deprecated; use 'read_database_uri' instead",
message="Use of a string URI with 'read_database' is deprecated; use `read_database_uri` instead",
version="0.19.0",
)
if iter_batches or batch_size:
raise InvalidOperationError(
"Batch parameters are not supported for `read_database_uri`"
)
if not isinstance(query, (list, str)):
raise TypeError(
f"`read_database_uri` expects one or more string queries; found {type(query)}"
@@ -453,8 +509,9 @@ def read_database( # noqa D417
return cx.execute(
query=query,
options=execute_options,
).to_frame(
).to_polars(
batch_size=batch_size,
iter_batches=iter_batches,
schema_overrides=schema_overrides,
)

@@ -608,15 +665,21 @@ def _read_sql_connectorx(
"\n\nPlease run: pip install connectorx>=0.3.2"
) from None

tbl = cx.read_sql(
conn=connection_uri,
query=query,
return_type="arrow2",
partition_on=partition_on,
partition_range=partition_range,
partition_num=partition_num,
protocol=protocol,
)
try:
tbl = cx.read_sql(
conn=connection_uri,
query=query,
return_type="arrow2",
partition_on=partition_on,
partition_range=partition_range,
partition_num=partition_num,
protocol=protocol,
)
except BaseException as err:
# basic sanitisation of /user:pass/ credentials exposed in connectorx errs
errmsg = re.sub("://[^:]+:[^:]+@", "://***:***@", str(err))
raise type(err)(errmsg) from err

return from_arrow(tbl, schema_overrides=schema_overrides) # type: ignore[return-value]
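A quick illustration of the credential-masking substitution added in the `except` branch above (the URI is made up for the example):

```python
import re

err = "connection failed: postgresql://usr:s3cret@localhost:5432/test"
print(re.sub("://[^:]+:[^:]+@", "://***:***@", err))
# -> connection failed: postgresql://***:***@localhost:5432/test
```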


6 changes: 4 additions & 2 deletions py-polars/polars/utils/_construction.py
@@ -626,7 +626,9 @@ def _handle_columns_arg(
data[i].rename(c)
return data
else:
raise ValueError("dimensions of columns arg must match data dimensions")
raise ValueError(
f"dimensions of columns arg ({len(columns)}) must match data dimensions ({len(data)})"
)


def _post_apply_columns(
@@ -1022,7 +1024,7 @@ def _sequence_of_sequence_to_pydf(
local_schema_override = (
include_unknowns(schema_overrides, column_names) if schema_overrides else {}
)
if column_names and len(first_element) != len(column_names):
if column_names and first_element and len(first_element) != len(column_names):
raise ShapeError("the row data does not match the number of columns")

unpack_nested = False
