Skip to content

Commit

Permalink
feat(python): add ODBC connection string support to read_database (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Oct 2, 2023
1 parent 3d2f3e5 commit 9026165
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 16 deletions.
90 changes: 74 additions & 16 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class _DriverProperties_(TypedDict):
"fetch_batches": None,
"exact_batch_size": None,
},
"arrow_odbc_proxy": {
"fetch_all": "fetchall",
"fetch_batches": "fetchmany",
"exact_batch_size": True,
},
"databricks": {
"fetch_all": "fetchall_arrow",
"fetch_batches": "fetchmany_arrow",
Expand Down Expand Up @@ -84,6 +89,35 @@ class _DriverProperties_(TypedDict):
}


class ODBCCursorProxy:
"""Cursor proxy for ODBC connections (requires `arrow-odbc`)."""

def __init__(self, connection_string: str) -> None:
self.connection_string = connection_string
self.query: str | None = None

def close(self) -> None:
"""Close the cursor (n/a: nothing to close)."""

def execute(self, query: str) -> None:
"""Execute a query (n/a: just store query for the fetch* methods)."""
self.query = query

def fetchmany(
self, batch_size: int = 10_000
) -> Iterable[pa.RecordBatch | pa.Table]:
"""Fetch results in batches."""
from arrow_odbc import read_arrow_batches_from_odbc

yield from read_arrow_batches_from_odbc(
query=self.query,
batch_size=batch_size,
connection_string=self.connection_string,
)

fetchall = fetchmany


class ConnectionExecutor:
"""Abstraction for querying databases with user-supplied connection objects."""

Expand All @@ -93,7 +127,11 @@ class ConnectionExecutor:
acquired_cursor: bool = False

def __init__(self, connection: ConnectionOrCursor) -> None:
self.driver_name = type(connection).__module__.split(".", 1)[0].lower()
self.driver_name = (
"arrow_odbc_proxy"
if isinstance(connection, ODBCCursorProxy)
else type(connection).__module__.split(".", 1)[0].lower()
)
self.cursor = self._normalise_cursor(connection)
self.result: Any = None

Expand Down Expand Up @@ -270,7 +308,7 @@ def to_frame(
@deprecate_renamed_parameter("connection_uri", "connection", version="0.18.9")
def read_database( # noqa: D417
query: str | Selectable,
connection: ConnectionOrCursor,
connection: ConnectionOrCursor | str,
*,
batch_size: int | None = None,
schema_overrides: SchemaDict | None = None,
Expand All @@ -286,7 +324,8 @@ def read_database( # noqa: D417
be a suitable "Selectable", otherwise it is expected to be a string).
connection
An instantiated connection (or cursor/client object) that the query can be
executed against.
executed against. Can also pass a valid ODBC connection string here, if you
have installed the ``arrow-odbc`` driver/package.
batch_size
Enable batched data fetching (internally) instead of collecting all rows at
once; this can be helpful for minimising the peak memory used for very large
Expand Down Expand Up @@ -342,22 +381,41 @@ def read_database( # noqa: D417
... schema_overrides={"normalised_score": pl.UInt8},
... ) # doctest: +SKIP
"""
Instantiate a DataFrame using an ODBC connection string (requires ``arrow-odbc``):
>>> df = pl.read_database(
... query="SELECT * FROM test_data",
... connection="Driver={PostgreSQL};Server=localhost;Port=5432;Database=test;Uid=usr;Pwd=",
... ) # doctest: +SKIP
""" # noqa: W505
if isinstance(connection, str):
issue_deprecation_warning(
message="Use of a string URI with 'read_database' is deprecated; use 'read_database_uri' instead",
version="0.19.0",
)
if not isinstance(query, (list, str)):
raise TypeError(
f"`read_database_uri` expects one or more string queries; found {type(query)}"
if re.sub(r"\s", "", connection[:20]).lower().startswith("driver="):
try:
import arrow_odbc # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"use of an ODBC connection string requires the `arrow-odbc` package."
"\n\nPlease run `pip install arrow-odbc`."
) from None

connection = ODBCCursorProxy(connection)
else:
issue_deprecation_warning(
message="Use of a string URI with 'read_database' is deprecated; use 'read_database_uri' instead",
version="0.19.0",
)
return read_database_uri(
query, uri=connection, schema_overrides=schema_overrides, **kwargs
)
elif kwargs:
if not isinstance(query, (list, str)):
raise TypeError(
f"`read_database_uri` expects one or more string queries; found {type(query)}"
)
return read_database_uri(
query, uri=connection, schema_overrides=schema_overrides, **kwargs
)

if kwargs:
raise ValueError(
f"`read_database` **kwargs only exist for deprecating string URIs: found {kwargs!r}"
f"`read_database` **kwargs only exist for passthrough to `read_database_uri`: found {kwargs!r}"
)

with ConnectionExecutor(connection) as cx:
Expand Down
1 change: 1 addition & 0 deletions py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ module = [
"IPython.*",
"adbc_driver_postgresql.*",
"adbc_driver_sqlite.*",
"arrow_odbc",
"backports",
"connectorx",
"deltalake.*",
Expand Down

0 comments on commit 9026165

Please sign in to comment.