feat(python): Add support for async SQLAlchemy connections to `read_database` (#15162)
alexander-beedie committed Mar 21, 2024
1 parent 97eff07 commit cb3efbc
Showing 5 changed files with 183 additions and 80 deletions.
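The headline change: `read_database` now accepts SQLAlchemy's async primitives (`AsyncConnection`, `AsyncEngine`, and `async_sessionmaker`) directly. A minimal usage sketch of what this enables; the `test.db` path and `test_data` table are illustrative placeholders, not part of the commit:

```python
# Hypothetical usage of the new capability: an async SQLAlchemy engine is
# passed straight to pl.read_database, which drives the coroutine to
# completion internally. Database path and table name are placeholders.
import polars as pl
from sqlalchemy.ext.asyncio import create_async_engine

async_engine = create_async_engine("sqlite+aiosqlite:///test.db")

df = pl.read_database(
    query="SELECT id, name, value FROM test_data",
    connection=async_engine,
)
print(df)
```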
17 changes: 8 additions & 9 deletions py-polars/polars/datatypes/convert.py
@@ -280,15 +280,14 @@ def _infer_dtype_from_database_typename(
return None # there's a timezone, but we don't know what it is
unit = _timeunit_from_precision(modifier) if modifier else "us"
dtype = Datetime(time_unit=(unit or "us")) # type: ignore[arg-type]

elif re.sub(r"\d", "", value) in ("INTERVAL", "TIMEDELTA"):
dtype = Duration

elif value in ("DATE", "DATE32", "DATE64"):
dtype = Date

elif value in ("TIME", "TIME32", "TIME64"):
dtype = Time
else:
value = re.sub(r"\d", "", value)
if value in ("INTERVAL", "TIMEDELTA"):
dtype = Duration
elif value == "DATE":
dtype = Date
elif value == "TIME":
dtype = Time

if not dtype and raise_unmatched:
msg = f"cannot infer dtype from {original_value!r} string value"
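The restructured branch strips digits from the type name only in the fallback path, so suffixed names such as `DATE32` or `TIME64` still resolve. A standalone sketch of that fallback logic (not the private helper itself), using the public `polars.datatypes` classes:

```python
# Illustrative re-statement of the revised fallback: digits are stripped
# before matching, so "DATE32"/"DATE64" map to Date, "TIME32"/"TIME64" map
# to Time, and "INTERVAL"/"TIMEDELTA" map to Duration.
import re

from polars.datatypes import Date, Duration, Time

def infer_temporal_dtype_sketch(value: str):
    stripped = re.sub(r"\d", "", value.strip().upper())
    if stripped in ("INTERVAL", "TIMEDELTA"):
        return Duration
    elif stripped == "DATE":
        return Date
    elif stripped == "TIME":
        return Time
    return None

assert infer_temporal_dtype_sketch("DATE32") is Date
assert infer_temporal_dtype_sketch("time64") is Time
assert infer_temporal_dtype_sketch("VARCHAR") is None
```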
191 changes: 129 additions & 62 deletions py-polars/polars/io/database.py
@@ -2,6 +2,8 @@

import re
import sys
import warnings
from collections.abc import Coroutine
from contextlib import suppress
from importlib import import_module
from inspect import Parameter, isclass, signature
@@ -47,6 +49,8 @@
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]

from sqlalchemy.sql.elements import TextClause


class _ArrowDriverProperties_(TypedDict):
# name of the method that fetches all arrow data; tuple form
@@ -201,12 +205,21 @@ def __exit__(
) -> None:
# if we created it and are finished with it, we can
# close the cursor (but NOT the connection)
if self.can_close_cursor and hasattr(self.cursor, "close"):
if type(self.cursor).__name__ == "AsyncConnection":
self._run_async(self._close_async_cursor())
elif self.can_close_cursor and hasattr(self.cursor, "close"):
self.cursor.close()

def __repr__(self) -> str:
return f"<{type(self).__name__} module={self.driver_name!r}>"

async def _close_async_cursor(self) -> None:
if self.can_close_cursor and hasattr(self.cursor, "close"):
from sqlalchemy.ext.asyncio.exc import AsyncContextNotStarted

with suppress(AsyncContextNotStarted):
await self.cursor.close()

def _fetch_arrow(
self,
driver_properties: _ArrowDriverProperties_,
@@ -306,39 +319,61 @@ def _from_rows(
"""Return resultset data row-wise for frame init."""
from polars import DataFrame

if hasattr(self.result, "fetchall"):
if self.driver_name == "sqlalchemy":
if hasattr(self.result, "cursor"):
cursor_desc = {d[0]: d[1:] for d in self.result.cursor.description}
elif hasattr(self.result, "_metadata"):
cursor_desc = {k: None for k in self.result._metadata.keys}
if is_async := isinstance(original_result := self.result, Coroutine):
self.result = self._run_async(self.result)
try:
if hasattr(self.result, "fetchall"):
if self.driver_name == "sqlalchemy":
if hasattr(self.result, "cursor"):
cursor_desc = {
d[0]: d[1:] for d in self.result.cursor.description
}
elif hasattr(self.result, "_metadata"):
cursor_desc = {k: None for k in self.result._metadata.keys}
else:
msg = f"Unable to determine metadata from query result; {self.result!r}"
raise ValueError(msg)
else:
msg = f"Unable to determine metadata from query result; {self.result!r}"
raise ValueError(msg)
else:
cursor_desc = {d[0]: d[1:] for d in self.result.description}
cursor_desc = {d[0]: d[1:] for d in self.result.description}

schema_overrides = self._inject_type_overrides(
description=cursor_desc,
schema_overrides=(schema_overrides or {}),
)
result_columns = list(cursor_desc)
frames = (
DataFrame(
data=rows,
schema=result_columns,
schema_overrides=schema_overrides,
infer_schema_length=infer_schema_length,
orient="row",
schema_overrides = self._inject_type_overrides(
description=cursor_desc,
schema_overrides=(schema_overrides or {}),
)
for rows in (
list(self._fetchmany_rows(self.result, batch_size))
if iter_batches
else [self._fetchall_rows(self.result)] # type: ignore[list-item]
result_columns = list(cursor_desc)
frames = (
DataFrame(
data=rows,
schema=result_columns,
schema_overrides=schema_overrides,
infer_schema_length=infer_schema_length,
orient="row",
)
for rows in (
list(self._fetchmany_rows(self.result, batch_size))
if iter_batches
else [self._fetchall_rows(self.result)] # type: ignore[list-item]
)
)
)
return frames if iter_batches else next(frames) # type: ignore[arg-type]
return None
return frames if iter_batches else next(frames) # type: ignore[arg-type]
return None
finally:
if is_async:
original_result.close()

@staticmethod
def _run_async(co: Coroutine) -> Any: # type: ignore[type-arg]
"""Consolidate async event loop acquisition and coroutine/func execution."""
import asyncio

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(co)
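For reference, a self-contained sketch of the sync-over-async pattern `_run_async` relies on: reuse the current event loop if one is available, otherwise create and register a fresh one, then block until the coroutine completes (the helper above additionally suppresses the `DeprecationWarning` that newer Python versions emit from `get_event_loop`). The coroutine and function names below are invented for the example:

```python
# Minimal illustration of the event-loop acquisition pattern.
import asyncio

async def add(a: int, b: int) -> int:
    return a + b

def run_sync(coro):
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # no usable loop in this thread: create and install a new one
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)

print(run_sync(add(2, 3)))  # -> 5
```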

def _inject_type_overrides(
self,
@@ -398,8 +433,10 @@ def _inject_type_overrides(
def _normalise_cursor(self, conn: Any) -> Cursor:
"""Normalise a connection object such that we have the query executor."""
if self.driver_name == "sqlalchemy":
self.can_close_cursor = (conn_type := type(conn).__name__) == "Engine"
if conn_type == "Session":
conn_type = type(conn).__name__
self.can_close_cursor = conn_type.endswith("Engine")

if conn_type in ("Session", "async_sessionmaker"):
return conn
else:
# where possible, use the raw connection to access arrow integration
@@ -409,7 +446,7 @@ def _normalise_cursor(self, conn: Any) -> Cursor:
elif conn.engine.driver == "duckdb_engine":
self.driver_name = "duckdb"
return conn.engine.raw_connection().driver_connection.c
elif conn_type == "Engine":
elif conn_type in ("AsyncEngine", "Engine"):
return conn.connect()
else:
return conn
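To make the branching above concrete, a sketch of the SQLAlchemy connection flavours `_normalise_cursor` now distinguishes; only the async variants are new in this commit, and the URLs are placeholders:

```python
# Illustrative construction of each connection type and how, per the code
# above, it is expected to be handled.
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite:///test.db")                         # "Engine": conn.connect()
session = sessionmaker(engine)()                                    # "Session": used as-is
async_engine = create_async_engine("sqlite+aiosqlite:///test.db")   # "AsyncEngine": conn.connect()
async_session = async_sessionmaker(async_engine)                    # "async_sessionmaker": used as-is
```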
@@ -427,9 +464,63 @@ def _normalise_cursor(self, conn: Any) -> Cursor:
msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method"
raise TypeError(msg)

async def _sqlalchemy_async_execute(self, query: TextClause, **options: Any) -> Any:
"""Execute a query using an async SQLAlchemy connection."""
is_session = type(self.cursor).__name__ == "async_sessionmaker"
cursor = self.cursor.begin() if is_session else self.cursor # type: ignore[attr-defined]
async with cursor as conn:
result = await conn.execute(query, **options)
return result
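A hedged sketch of the execution pattern this method wraps: the async connection is entered as an async context manager and the query is awaited, yielding a buffered result whose rows can then be fetched synchronously. The in-memory URL and query are placeholders:

```python
# Standalone illustration of awaiting a query on an AsyncConnection; this is
# the shape of the work _sqlalchemy_async_execute performs on polars' behalf.
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

async def fetch_one() -> list:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.connect() as conn:
        result = await conn.execute(text("SELECT 1 AS x"))
        return result.fetchall()

print(asyncio.run(fetch_one()))  # -> [(1,)]
```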

def _sqlalchemy_setup(
self, query: str | TextClause | Selectable, options: dict[str, Any]
) -> tuple[Any, dict[str, Any], str | TextClause | Selectable]:
"""Prepare a query for execution using a SQLAlchemy connection."""
from sqlalchemy.orm import Session
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import TextClause

is_async = type(self.cursor).__name__ in (
"AsyncConnection",
"async_sessionmaker",
)
param_key = "parameters"
cursor_execute = None
if (
isinstance(self.cursor, Session)
and "parameters" in options
and "params" not in options
):
options = options.copy()
options["params"] = options.pop("parameters")
param_key = "params"

params = options.get(param_key)
if (
not is_async
and isinstance(params, Sequence)
and hasattr(self.cursor, "exec_driver_sql")
):
cursor_execute = self.cursor.exec_driver_sql
if isinstance(query, TextClause):
query = str(query)
if isinstance(params, list) and not all(
isinstance(p, (dict, tuple)) for p in params
):
options[param_key] = tuple(params)

elif isinstance(query, str):
query = text(query)

if cursor_execute is None:
cursor_execute = (
self._sqlalchemy_async_execute if is_async else self.cursor.execute
)
return cursor_execute, options, query
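From the user's side, the parameter handling prepared here surfaces through the `execute_options` argument of `read_database`. A hedged sketch with placeholder database, table, and parameter names:

```python
# Named parameters supplied via execute_options are bound into the text()
# clause built above; the values shown are illustrative only.
import polars as pl
from sqlalchemy.ext.asyncio import create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///test.db")
df = pl.read_database(
    query="SELECT id, name, value FROM test_data WHERE value > :n",
    connection=engine,
    execute_options={"parameters": {"n": -1000}},
)
```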

def execute(
self,
query: str | Selectable,
query: str | TextClause | Selectable,
*,
options: dict[str, Any] | None = None,
select_queries_only: bool = True,
@@ -442,42 +533,18 @@ def execute(
raise UnsuitableSQLError(msg)

options = options or {}
cursor_execute = self.cursor.execute

if self.driver_name == "sqlalchemy":
from sqlalchemy.orm import Session

param_key = "parameters"
if (
isinstance(self.cursor, Session)
and "parameters" in options
and "params" not in options
):
options = options.copy()
options["params"] = options.pop("parameters")
param_key = "params"

if isinstance(query, str):
params = options.get(param_key)
if isinstance(params, Sequence) and hasattr(
self.cursor, "exec_driver_sql"
):
cursor_execute = self.cursor.exec_driver_sql
if isinstance(params, list) and not all(
isinstance(p, (dict, tuple)) for p in params
):
options[param_key] = tuple(params)
else:
from sqlalchemy.sql import text

query = text(query) # type: ignore[assignment]
cursor_execute, options, query = self._sqlalchemy_setup(query, options)
else:
cursor_execute = self.cursor.execute

# note: some cursor execute methods (eg: sqlite3) only take positional
# params, hence the slightly convoluted resolution of the 'options' dict
try:
params = signature(cursor_execute).parameters
except ValueError:
params = {}
params = {} # type: ignore[assignment]

if not options or any(
p.kind in (Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
18 changes: 9 additions & 9 deletions py-polars/polars/lazyframe/frame.py
@@ -1995,12 +1995,12 @@ def collect_async(
This functionality is considered **unstable**. It may be changed
at any point without it being considered a breaking change.
Collects into a DataFrame (like :func:`collect`), but instead of returning
DataFrame directly, they are scheduled to be collected inside thread pool,
Collects into a DataFrame (like :func:`collect`) but, instead of returning
a DataFrame directly, it is scheduled to be collected inside a thread pool,
while this method returns almost instantly.
May be useful if you use gevent or asyncio and want to release control to other
greenlets/tasks while LazyFrames are being collected.
This can be useful if you use `gevent` or `asyncio` and want to release
control to other greenlets/tasks while LazyFrames are being collected.
Parameters
----------
@@ -2032,20 +2032,20 @@ def collect_async(
at any point without it being considered a breaking change.
.. note::
Use :func:`explain` to see if Polars can process the query in streaming
mode.
Use :func:`explain` to see if Polars can process the query in
streaming mode.
Returns
-------
If `gevent=False` (default) then returns awaitable.
If `gevent=False` (default) then returns an awaitable.
If `gevent=True` then returns wrapper that has
If `gevent=True` then returns wrapper that has a
`.get(block=True, timeout=None)` method.
See Also
--------
polars.collect_all : Collect multiple LazyFrames at the same time.
polars.collect_all_async: Collect multiple LazyFrames at the same time lazily.
polars.collect_all_async : Collect multiple LazyFrames at the same time lazily.
Notes
-----
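The corrected docstring describes behaviour that, with the existing API, looks roughly like the following minimal asyncio sketch; the frame contents are illustrative:

```python
# collect_async with gevent=False (the default) returns an awaitable, so the
# query runs on polars' thread pool while the event loop stays free.
import asyncio

import polars as pl

async def main() -> None:
    lf = pl.LazyFrame({"a": [1, 2, 3]}).select(pl.col("a").sum())
    df = await lf.collect_async()
    print(df)

asyncio.run(main())
```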
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
@@ -26,6 +26,7 @@ tzdata; platform_system == 'Windows'
SQLAlchemy
adbc_driver_manager; python_version >= '3.9' and platform_system != 'Windows'
adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows'
aiosqlite
# TODO: Remove version constraint for connectorx when Python 3.12 is supported:
# https://github.com/sfu-db/connector-x/issues/527
connectorx; python_version <= '3.11'
36 changes: 36 additions & 0 deletions py-polars/tests/unit/io/test_database_read.py
@@ -12,6 +12,7 @@
import pyarrow as pa
import pytest
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import cast as alchemy_cast

@@ -889,3 +890,38 @@ def test_database_dtype_inference_from_invalid_string(value: str) -> None:
raise_unmatched=False,
)
assert inferred_dtype is None


def test_read_database_async(tmp_sqlite_db: Path) -> None:
# confirm that we can load frame data from the core sqlalchemy async
# primitives: AsyncConnection, AsyncEngine, and async_sessionmaker

async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
async_connection = async_engine.connect()
async_session = async_sessionmaker(async_engine)

expected_frame = pl.DataFrame(
{"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]}
)
async_conn: Any
for async_conn in (
async_engine,
async_connection,
async_session,
):
if async_conn is async_session:
constraint, execute_opts = "", {}
else:
constraint = "WHERE value > :n"
execute_opts = {"parameters": {"n": -1000}}

df = pl.read_database(
query=f"""
SELECT id, name, value
FROM test_data {constraint}
ORDER BY id DESC
""",
connection=async_conn,
execute_options=execute_opts,
)
assert_frame_equal(expected_frame, df)
