From 90261657bda82667f69736dbba7861f639decadf Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Mon, 2 Oct 2023 17:35:39 +0400 Subject: [PATCH] feat(python): add ODBC connection string support to `read_database` (#11448) --- py-polars/polars/io/database.py | 90 +++++++++++++++++++++++++++------ py-polars/pyproject.toml | 1 + 2 files changed, 75 insertions(+), 16 deletions(-) diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index a36297922a65..e04eecb91a6b 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -47,6 +47,11 @@ class _DriverProperties_(TypedDict): "fetch_batches": None, "exact_batch_size": None, }, + "arrow_odbc_proxy": { + "fetch_all": "fetchall", + "fetch_batches": "fetchmany", + "exact_batch_size": True, + }, "databricks": { "fetch_all": "fetchall_arrow", "fetch_batches": "fetchmany_arrow", @@ -84,6 +89,35 @@ class _DriverProperties_(TypedDict): } +class ODBCCursorProxy: + """Cursor proxy for ODBC connections (requires `arrow-odbc`).""" + + def __init__(self, connection_string: str) -> None: + self.connection_string = connection_string + self.query: str | None = None + + def close(self) -> None: + """Close the cursor (n/a: nothing to close).""" + + def execute(self, query: str) -> None: + """Execute a query (n/a: just store query for the fetch* methods).""" + self.query = query + + def fetchmany( + self, batch_size: int = 10_000 + ) -> Iterable[pa.RecordBatch | pa.Table]: + """Fetch results in batches.""" + from arrow_odbc import read_arrow_batches_from_odbc + + yield from read_arrow_batches_from_odbc( + query=self.query, + batch_size=batch_size, + connection_string=self.connection_string, + ) + + fetchall = fetchmany + + class ConnectionExecutor: """Abstraction for querying databases with user-supplied connection objects.""" @@ -93,7 +127,11 @@ class ConnectionExecutor: acquired_cursor: bool = False def __init__(self, connection: ConnectionOrCursor) -> None: - self.driver_name = type(connection).__module__.split(".", 1)[0].lower() + self.driver_name = ( + "arrow_odbc_proxy" + if isinstance(connection, ODBCCursorProxy) + else type(connection).__module__.split(".", 1)[0].lower() + ) self.cursor = self._normalise_cursor(connection) self.result: Any = None @@ -270,7 +308,7 @@ def to_frame( @deprecate_renamed_parameter("connection_uri", "connection", version="0.18.9") def read_database( # noqa: D417 query: str | Selectable, - connection: ConnectionOrCursor, + connection: ConnectionOrCursor | str, *, batch_size: int | None = None, schema_overrides: SchemaDict | None = None, @@ -286,7 +324,8 @@ def read_database( # noqa: D417 be a suitable "Selectable", otherwise it is expected to be a string). connection An instantiated connection (or cursor/client object) that the query can be - executed against. + executed against. Can also pass a valid ODBC connection string here, if you + have installed the ``arrow-odbc`` driver/package. batch_size Enable batched data fetching (internally) instead of collecting all rows at once; this can be helpful for minimising the peak memory used for very large @@ -342,22 +381,41 @@ def read_database( # noqa: D417 ... schema_overrides={"normalised_score": pl.UInt8}, ... ) # doctest: +SKIP - """ + Instantiate a DataFrame using an ODBC connection string (requires ``arrow-odbc``): + + >>> df = pl.read_database( + ... query="SELECT * FROM test_data", + ... connection="Driver={PostgreSQL};Server=localhost;Port=5432;Database=test;Uid=usr;Pwd=", + ... ) # doctest: +SKIP + + """ # noqa: W505 if isinstance(connection, str): - issue_deprecation_warning( - message="Use of a string URI with 'read_database' is deprecated; use 'read_database_uri' instead", - version="0.19.0", - ) - if not isinstance(query, (list, str)): - raise TypeError( - f"`read_database_uri` expects one or more string queries; found {type(query)}" + if re.sub(r"\s", "", connection[:20]).lower().startswith("driver="): + try: + import arrow_odbc # noqa: F401 + except ModuleNotFoundError: + raise ModuleNotFoundError( + "use of an ODBC connection string requires the `arrow-odbc` package." + "\n\nPlease run `pip install arrow-odbc`." + ) from None + + connection = ODBCCursorProxy(connection) + else: + issue_deprecation_warning( + message="Use of a string URI with 'read_database' is deprecated; use 'read_database_uri' instead", + version="0.19.0", ) - return read_database_uri( - query, uri=connection, schema_overrides=schema_overrides, **kwargs - ) - elif kwargs: + if not isinstance(query, (list, str)): + raise TypeError( + f"`read_database_uri` expects one or more string queries; found {type(query)}" + ) + return read_database_uri( + query, uri=connection, schema_overrides=schema_overrides, **kwargs + ) + + if kwargs: raise ValueError( - f"`read_database` **kwargs only exist for deprecating string URIs: found {kwargs!r}" + f"`read_database` **kwargs only exist for passthrough to `read_database_uri`: found {kwargs!r}" ) with ConnectionExecutor(connection) as cx: diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 062d4de4ffea..21ce228ce519 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -77,6 +77,7 @@ module = [ "IPython.*", "adbc_driver_postgresql.*", "adbc_driver_sqlite.*", + "arrow_odbc", "backports", "connectorx", "deltalake.*",