From cb3efbc7df8ed6225024519b86c356d9bda816a7 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 21 Mar 2024 12:25:03 +0400 Subject: [PATCH] feat(python): Add support for `async` SQLAlchemy connections to `read_database` (#15162) --- py-polars/polars/datatypes/convert.py | 17 +- py-polars/polars/io/database.py | 191 ++++++++++++------ py-polars/polars/lazyframe/frame.py | 18 +- py-polars/requirements-dev.txt | 1 + py-polars/tests/unit/io/test_database_read.py | 36 ++++ 5 files changed, 183 insertions(+), 80 deletions(-) diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index 6630069575ba..fa3835007989 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -280,15 +280,14 @@ def _infer_dtype_from_database_typename( return None # there's a timezone, but we don't know what it is unit = _timeunit_from_precision(modifier) if modifier else "us" dtype = Datetime(time_unit=(unit or "us")) # type: ignore[arg-type] - - elif re.sub(r"\d", "", value) in ("INTERVAL", "TIMEDELTA"): - dtype = Duration - - elif value in ("DATE", "DATE32", "DATE64"): - dtype = Date - - elif value in ("TIME", "TIME32", "TIME64"): - dtype = Time + else: + value = re.sub(r"\d", "", value) + if value in ("INTERVAL", "TIMEDELTA"): + dtype = Duration + elif value == "DATE": + dtype = Date + elif value == "TIME": + dtype = Time if not dtype and raise_unmatched: msg = f"cannot infer dtype from {original_value!r} string value" diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index abb432c88e3c..672f76ef35be 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -2,6 +2,8 @@ import re import sys +import warnings +from collections.abc import Coroutine from contextlib import suppress from importlib import import_module from inspect import Parameter, isclass, signature @@ -47,6 +49,8 @@ except ImportError: Selectable: TypeAlias = Any # type: ignore[no-redef] + from sqlalchemy.sql.elements import TextClause + class _ArrowDriverProperties_(TypedDict): # name of the method that fetches all arrow data; tuple form @@ -201,12 +205,21 @@ def __exit__( ) -> None: # if we created it and are finished with it, we can # close the cursor (but NOT the connection) - if self.can_close_cursor and hasattr(self.cursor, "close"): + if type(self.cursor).__name__ == "AsyncConnection": + self._run_async(self._close_async_cursor()) + elif self.can_close_cursor and hasattr(self.cursor, "close"): self.cursor.close() def __repr__(self) -> str: return f"<{type(self).__name__} module={self.driver_name!r}>" + async def _close_async_cursor(self) -> None: + if self.can_close_cursor and hasattr(self.cursor, "close"): + from sqlalchemy.ext.asyncio.exc import AsyncContextNotStarted + + with suppress(AsyncContextNotStarted): + await self.cursor.close() + def _fetch_arrow( self, driver_properties: _ArrowDriverProperties_, @@ -306,39 +319,61 @@ def _from_rows( """Return resultset data row-wise for frame init.""" from polars import DataFrame - if hasattr(self.result, "fetchall"): - if self.driver_name == "sqlalchemy": - if hasattr(self.result, "cursor"): - cursor_desc = {d[0]: d[1:] for d in self.result.cursor.description} - elif hasattr(self.result, "_metadata"): - cursor_desc = {k: None for k in self.result._metadata.keys} + if is_async := isinstance(original_result := self.result, Coroutine): + self.result = self._run_async(self.result) + try: + if hasattr(self.result, "fetchall"): + if self.driver_name == "sqlalchemy": 
+ if hasattr(self.result, "cursor"): + cursor_desc = { + d[0]: d[1:] for d in self.result.cursor.description + } + elif hasattr(self.result, "_metadata"): + cursor_desc = {k: None for k in self.result._metadata.keys} + else: + msg = f"Unable to determine metadata from query result; {self.result!r}" + raise ValueError(msg) else: - msg = f"Unable to determine metadata from query result; {self.result!r}" - raise ValueError(msg) - else: - cursor_desc = {d[0]: d[1:] for d in self.result.description} + cursor_desc = {d[0]: d[1:] for d in self.result.description} - schema_overrides = self._inject_type_overrides( - description=cursor_desc, - schema_overrides=(schema_overrides or {}), - ) - result_columns = list(cursor_desc) - frames = ( - DataFrame( - data=rows, - schema=result_columns, - schema_overrides=schema_overrides, - infer_schema_length=infer_schema_length, - orient="row", + schema_overrides = self._inject_type_overrides( + description=cursor_desc, + schema_overrides=(schema_overrides or {}), ) - for rows in ( - list(self._fetchmany_rows(self.result, batch_size)) - if iter_batches - else [self._fetchall_rows(self.result)] # type: ignore[list-item] + result_columns = list(cursor_desc) + frames = ( + DataFrame( + data=rows, + schema=result_columns, + schema_overrides=schema_overrides, + infer_schema_length=infer_schema_length, + orient="row", + ) + for rows in ( + list(self._fetchmany_rows(self.result, batch_size)) + if iter_batches + else [self._fetchall_rows(self.result)] # type: ignore[list-item] + ) ) - ) - return frames if iter_batches else next(frames) # type: ignore[arg-type] - return None + return frames if iter_batches else next(frames) # type: ignore[arg-type] + return None + finally: + if is_async: + original_result.close() + + @staticmethod + def _run_async(co: Coroutine) -> Any: # type: ignore[type-arg] + """Consolidate async event loop acquisition and coroutine/func execution.""" + import asyncio + + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(co) def _inject_type_overrides( self, @@ -398,8 +433,10 @@ def _inject_type_overrides( def _normalise_cursor(self, conn: Any) -> Cursor: """Normalise a connection object such that we have the query executor.""" if self.driver_name == "sqlalchemy": - self.can_close_cursor = (conn_type := type(conn).__name__) == "Engine" - if conn_type == "Session": + conn_type = type(conn).__name__ + self.can_close_cursor = conn_type.endswith("Engine") + + if conn_type in ("Session", "async_sessionmaker"): return conn else: # where possible, use the raw connection to access arrow integration @@ -409,7 +446,7 @@ def _normalise_cursor(self, conn: Any) -> Cursor: elif conn.engine.driver == "duckdb_engine": self.driver_name = "duckdb" return conn.engine.raw_connection().driver_connection.c - elif conn_type == "Engine": + elif conn_type in ("AsyncEngine", "Engine"): return conn.connect() else: return conn @@ -427,9 +464,63 @@ def _normalise_cursor(self, conn: Any) -> Cursor: msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method" raise TypeError(msg) + async def _sqlalchemy_async_execute(self, query: TextClause, **options: Any) -> Any: + """Execute a query using an async SQLAlchemy connection.""" + is_session = type(self.cursor).__name__ == "async_sessionmaker" + cursor = self.cursor.begin() if is_session else self.cursor # type: 
ignore[attr-defined] + async with cursor as conn: + result = await conn.execute(query, **options) + return result + + def _sqlalchemy_setup( + self, query: str | TextClause | Selectable, options: dict[str, Any] + ) -> tuple[Any, dict[str, Any], str | TextClause | Selectable]: + """Prepare a query for execution using a SQLAlchemy connection.""" + from sqlalchemy.orm import Session + from sqlalchemy.sql import text + from sqlalchemy.sql.elements import TextClause + + is_async = type(self.cursor).__name__ in ( + "AsyncConnection", + "async_sessionmaker", + ) + param_key = "parameters" + cursor_execute = None + if ( + isinstance(self.cursor, Session) + and "parameters" in options + and "params" not in options + ): + options = options.copy() + options["params"] = options.pop("parameters") + param_key = "params" + + params = options.get(param_key) + if ( + not is_async + and isinstance(params, Sequence) + and hasattr(self.cursor, "exec_driver_sql") + ): + cursor_execute = self.cursor.exec_driver_sql + if isinstance(query, TextClause): + query = str(query) + if isinstance(params, list) and not all( + isinstance(p, (dict, tuple)) for p in params + ): + options[param_key] = tuple(params) + + elif isinstance(query, str): + query = text(query) + + if cursor_execute is None: + cursor_execute = ( + self._sqlalchemy_async_execute if is_async else self.cursor.execute + ) + return cursor_execute, options, query + def execute( self, - query: str | Selectable, + query: str | TextClause | Selectable, *, options: dict[str, Any] | None = None, select_queries_only: bool = True, @@ -442,42 +533,18 @@ def execute( raise UnsuitableSQLError(msg) options = options or {} - cursor_execute = self.cursor.execute if self.driver_name == "sqlalchemy": - from sqlalchemy.orm import Session - - param_key = "parameters" - if ( - isinstance(self.cursor, Session) - and "parameters" in options - and "params" not in options - ): - options = options.copy() - options["params"] = options.pop("parameters") - param_key = "params" - - if isinstance(query, str): - params = options.get(param_key) - if isinstance(params, Sequence) and hasattr( - self.cursor, "exec_driver_sql" - ): - cursor_execute = self.cursor.exec_driver_sql - if isinstance(params, list) and not all( - isinstance(p, (dict, tuple)) for p in params - ): - options[param_key] = tuple(params) - else: - from sqlalchemy.sql import text - - query = text(query) # type: ignore[assignment] + cursor_execute, options, query = self._sqlalchemy_setup(query, options) + else: + cursor_execute = self.cursor.execute # note: some cursor execute methods (eg: sqlite3) only take positional # params, hence the slightly convoluted resolution of the 'options' dict try: params = signature(cursor_execute).parameters except ValueError: - params = {} + params = {} # type: ignore[assignment] if not options or any( p.kind in (Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 8947e9618f20..4cc5b0c6130c 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1995,12 +1995,12 @@ def collect_async( This functionality is considered **unstable**. It may be changed at any point without it being considered a breaking change. 
- Collects into a DataFrame (like :func:`collect`), but instead of returning - DataFrame directly, they are scheduled to be collected inside thread pool, + Collects into a DataFrame (like :func:`collect`) but, instead of returning + a DataFrame directly, it is scheduled to be collected inside a thread pool, while this method returns almost instantly. - May be useful if you use gevent or asyncio and want to release control to other - greenlets/tasks while LazyFrames are being collected. + This can be useful if you use `gevent` or `asyncio` and want to release + control to other greenlets/tasks while LazyFrames are being collected. Parameters ---------- @@ -2032,20 +2032,20 @@ def collect_async( at any point without it being considered a breaking change. .. note:: - Use :func:`explain` to see if Polars can process the query in streaming - mode. + Use :func:`explain` to see if Polars can process the query in + streaming mode. Returns ------- - If `gevent=False` (default) then returns awaitable. + If `gevent=False` (default) then returns an awaitable. - If `gevent=True` then returns wrapper that has + If `gevent=True` then returns wrapper that has a `.get(block=True, timeout=None)` method. See Also -------- polars.collect_all : Collect multiple LazyFrames at the same time. - polars.collect_all_async: Collect multiple LazyFrames at the same time lazily. + polars.collect_all_async : Collect multiple LazyFrames at the same time lazily. Notes ----- diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index abf2550ec9c1..c62c134ecfcc 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -26,6 +26,7 @@ tzdata; platform_system == 'Windows' SQLAlchemy adbc_driver_manager; python_version >= '3.9' and platform_system != 'Windows' adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows' +aiosqlite # TODO: Remove version constraint for connectorx when Python 3.12 is supported: # https://github.com/sfu-db/connector-x/issues/527 connectorx; python_version <= '3.11' diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 509ca2b4d2b6..c4650bc1abc6 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -12,6 +12,7 @@ import pyarrow as pa import pytest from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast @@ -889,3 +890,38 @@ def test_database_dtype_inference_from_invalid_string(value: str) -> None: raise_unmatched=False, ) assert inferred_dtype is None + + +def test_read_database_async(tmp_sqlite_db: Path) -> None: + # confirm that we can load frame data from the core sqlalchemy async + # primitives: AsyncConnection, AsyncEngine, and async_sessionmaker + + async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}") + async_connection = async_engine.connect() + async_session = async_sessionmaker(async_engine) + + expected_frame = pl.DataFrame( + {"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]} + ) + async_conn: Any + for async_conn in ( + async_engine, + async_connection, + async_session, + ): + if async_conn is async_session: + constraint, execute_opts = "", {} + else: + constraint = "WHERE value > :n" + execute_opts = {"parameters": {"n": -1000}} + + df = pl.read_database( + query=f""" + SELECT 
id, name, value + FROM test_data {constraint} + ORDER BY id DESC + """, + connection=async_conn, + execute_options=execute_opts, + ) + assert_frame_equal(expected_frame, df)
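
A minimal usage sketch of the async support added above, based on the new test; the
`sqlite+aiosqlite` URL, database path, and `test_data` table are illustrative
assumptions, and any async-capable SQLAlchemy engine should behave the same way:

    import polars as pl
    from sqlalchemy.ext.asyncio import create_async_engine

    # an AsyncEngine; AsyncConnection objects and async_sessionmaker factories
    # are accepted as `connection` in the same way (see the test above)
    async_engine = create_async_engine("sqlite+aiosqlite:///example.db")

    # read_database runs the coroutine on an internal event loop, so the
    # caller gets an ordinary (non-awaitable) DataFrame back
    df = pl.read_database(
        query="SELECT id, name, value FROM test_data WHERE value > :n ORDER BY id DESC",
        connection=async_engine,
        execute_options={"parameters": {"n": -1000}},
    )

Because the driver resolves the event loop itself (via `_run_async`), no `await` or
explicit `asyncio.run` is needed at the call site.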