feat(python): Add read_database support for SurrealDB ("ws" and "http") (#15269)
alexander-beedie committed Mar 27, 2024
1 parent 1a37f2b commit 1f3e1c4
Showing 11 changed files with 546 additions and 289 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/dataframe/frame.py
@@ -3475,7 +3475,7 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
)
raise ModuleNotFoundError(msg) from exc

from polars.io.database._uri import _open_adbc_connection
from polars.io.database._utils import _open_adbc_connection

if if_table_exists == "fail":
# if the table exists, 'create' will raise an error,
68 changes: 67 additions & 1 deletion py-polars/polars/io/database/_cursor_proxies.py
@@ -2,9 +2,19 @@

from typing import TYPE_CHECKING, Any, Iterable

from polars.io.database._utils import _run_async

if TYPE_CHECKING:
import sys
from collections.abc import Coroutine

import pyarrow as pa

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class ODBCCursorProxy:
"""Cursor proxy for ODBC connections (requires `arrow-odbc`)."""
@@ -15,7 +25,8 @@ def __init__(self, connection_string: str) -> None:
self.query: str | None = None

def close(self) -> None:
"""Close the cursor (n/a: nothing to close)."""
"""Close the cursor."""
# n/a: nothing to close

def execute(self, query: str, **execute_options: Any) -> None:
"""Execute a query (n/a: just store query for the fetch* methods)."""
@@ -57,3 +68,58 @@ def fetch_record_batches(
# note: internally arrow-odbc always reads batches
fetchall = fetch_arrow_table
fetchmany = fetch_record_batches


class SurrealDBCursorProxy:
"""Cursor proxy for SurrealDB connections (requires `surrealdb`)."""

_cached_result: list[dict[str, Any]] | None = None

def __init__(self, client: Any) -> None:
self.client = client
self.execute_options: dict[str, Any] = {}
self.query: str = None # type: ignore[assignment]

@staticmethod
async def _unpack_result(
result: Coroutine[Any, Any, list[dict[str, Any]]],
) -> list[dict[str, Any]]:
"""Unpack the async query result."""
response = (await result)[0]
if response["status"] != "OK":
raise RuntimeError(response["result"])
return response["result"]

def close(self) -> None:
"""Close the cursor."""
# no-op; never close a user's Surreal session

def execute(self, query: str, **execute_options: Any) -> Self:
"""Execute a query (n/a: just store query for the fetch* methods)."""
self._cached_result = None
self.execute_options = execute_options
self.query = query
return self

def fetchall(self) -> list[dict[str, Any]]:
"""Fetch all results (as a list of dictionaries)."""
return _run_async(
self._unpack_result(
result=self.client.query(
sql=self.query,
vars=(self.execute_options or None),
),
)
)

def fetchmany(self, size: int) -> list[dict[str, Any]]:
"""Fetch results in batches (simulated)."""
# first 'fetchmany' call acquires/caches the result
if self._cached_result is None:
self._cached_result = self.fetchall()

# return batches of the cached result; remove from the cache as
# we go, so as not to hold on to additional copies when done
result = self._cached_result[:size]
del self._cached_result[:size]
return result
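
The simulated batching above takes slices from the front of the cached result
and deletes them as it goes, so no extra copy of the full resultset is retained
once iteration finishes. A minimal standalone sketch of the same pattern, using
made-up sample rows:

cached = [{"id": i} for i in range(10)]

def fetch_batch(cache: list, size: int) -> list:
    batch = cache[:size]  # take the next batch...
    del cache[:size]  # ...and release it from the cache
    return batch

while batch := fetch_batch(cached, 4):
    print(len(batch))  # prints 4, 4, 2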
60 changes: 31 additions & 29 deletions py-polars/polars/io/database/_executor.py
@@ -6,6 +6,7 @@
from inspect import Parameter, isclass, signature
from typing import TYPE_CHECKING, Any, Iterable, Sequence

from polars import functions as F
from polars._utils.various import parse_version
from polars.convert import from_arrow
from polars.datatypes import (
@@ -19,11 +20,12 @@
from polars.datatypes.convert import _map_py_type_to_dtype
from polars.exceptions import ModuleUpgradeRequired, UnsuitableSQLError
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
from polars.io.database._cursor_proxies import ODBCCursorProxy
from polars.io.database._cursor_proxies import ODBCCursorProxy, SurrealDBCursorProxy
from polars.io.database._inference import (
_infer_dtype_from_database_typename,
_integer_dtype_from_nbits,
)
from polars.io.database._utils import _run_async

if TYPE_CHECKING:
import sys
@@ -82,6 +84,9 @@ def __init__(self, connection: ConnectionOrCursor) -> None:
if isinstance(connection, ODBCCursorProxy)
else type(connection).__module__.split(".", 1)[0].lower()
)
if self.driver_name == "surrealdb":
connection = SurrealDBCursorProxy(client=connection)

self.cursor = self._normalise_cursor(connection)
self.result: Any = None

@@ -97,13 +102,25 @@ def __exit__(
# if we created it and are finished with it, we can
# close the cursor (but NOT the connection)
if type(self.cursor).__name__ == "AsyncConnection":
self._run_async(self._close_async_cursor())
_run_async(self._close_async_cursor())
elif self.can_close_cursor and hasattr(self.cursor, "close"):
self.cursor.close()

def __repr__(self) -> str:
return f"<{type(self).__name__} module={self.driver_name!r}>"

@staticmethod
def _apply_overrides(df: DataFrame, schema_overrides: SchemaDict) -> DataFrame:
"""Apply schema overrides to a DataFrame."""
existing_schema = df.schema
if cast_cols := [
F.col(col).cast(dtype)
for col, dtype in schema_overrides.items()
if col in existing_schema and dtype != existing_schema[col]
]:
df = df.with_columns(cast_cols)
return df
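
`_apply_overrides` casts only columns that exist in the frame and whose dtype
actually differs from the requested one; override keys for absent columns are
ignored. A minimal sketch of the same logic on made-up data:

import polars as pl

df = pl.DataFrame({"id": [1, 2], "value": ["10", "20"]})
overrides = {"value": pl.Int64, "missing_col": pl.Float64}

existing = df.schema
if cast_cols := [
    pl.col(name).cast(dtype)
    for name, dtype in overrides.items()
    if name in existing and dtype != existing[name]
]:
    df = df.with_columns(cast_cols)
print(df.schema)  # id: Int64, value: Int64 ("missing_col" is ignored)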

async def _close_async_cursor(self) -> None:
if self.can_close_cursor and hasattr(self.cursor, "close"):
from sqlalchemy.ext.asyncio.exc import AsyncContextNotStarted
@@ -156,7 +173,7 @@ def _fetchall_rows(result: Cursor) -> Iterable[Sequence[Any]]:
rows = result.fetchall()
return (
[tuple(row) for row in rows]
if rows and not isinstance(rows[0], (list, tuple))
if rows and not isinstance(rows[0], (list, tuple, dict))
else rows
)

@@ -168,7 +185,7 @@ def _fetchmany_rows(
rows = result.fetchmany(batch_size)
if not rows:
break
elif isinstance(rows[0], (list, tuple)):
elif isinstance(rows[0], (list, tuple, dict)):
yield rows
else:
yield [tuple(row) for row in rows]
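
Adding `dict` to the row-type checks above lets SurrealDB's dict rows pass
through untouched, rather than being coerced to tuples; the DataFrame
constructor can then take column names from the dict keys when no cursor
description is available. A small sketch with made-up rows:

import polars as pl

rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
df = pl.DataFrame(data=rows, orient="row", infer_schema_length=100)
print(df.columns)  # ['id', 'name']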
@@ -192,7 +209,7 @@ def _from_arrow(
fetch_batches = driver_properties["fetch_batches"]
self.can_close_cursor = fetch_batches is None or not iter_batches
frames = (
batch
self._apply_overrides(batch, (schema_overrides or {}))
if isinstance(batch, DataFrame)
else from_arrow(batch, schema_overrides=schema_overrides)
for batch in self._fetch_arrow(
@@ -226,7 +243,7 @@ def _from_rows(
from polars import DataFrame

if is_async := isinstance(original_result := self.result, Coroutine):
self.result = self._run_async(self.result)
self.result = _run_async(self.result)
try:
if hasattr(self.result, "fetchall"):
if self.driver_name == "sqlalchemy":
@@ -239,8 +256,11 @@
else:
msg = f"Unable to determine metadata from query result; {self.result!r}"
raise ValueError(msg)
else:

elif hasattr(self.result, "description"):
cursor_desc = {d[0]: d[1:] for d in self.result.description}
else:
cursor_desc = {}

schema_overrides = self._inject_type_overrides(
description=cursor_desc,
@@ -250,7 +270,7 @@
frames = (
DataFrame(
data=rows,
schema=result_columns,
schema=result_columns or None,
schema_overrides=schema_overrides,
infer_schema_length=infer_schema_length,
orient="row",
@@ -267,24 +287,6 @@
if is_async:
original_result.close()

@staticmethod
def _run_async(co: Coroutine) -> Any: # type: ignore[type-arg]
"""Run asynchronous code as if it was synchronous."""
import asyncio

try:
import nest_asyncio

nest_asyncio.apply()
except ModuleNotFoundError as _err:
msg = (
"Executing using async drivers requires the `nest_asyncio` package."
"\n\nPlease run: pip install nest_asyncio"
)
raise ModuleNotFoundError(msg) from None

return asyncio.run(co)

@staticmethod
def _inject_type_overrides(
description: dict[str, Any],
@@ -348,8 +350,6 @@ def _normalise_cursor(self, conn: Any) -> Cursor:
"""Normalise a connection object such that we have the query executor."""
if self.driver_name == "sqlalchemy":
conn_type = type(conn).__name__
self.can_close_cursor = conn_type.endswith("Engine")

if conn_type in ("Session", "async_sessionmaker"):
return conn
else:
@@ -361,6 +361,8 @@ def _normalise_cursor(self, conn: Any) -> Cursor:
self.driver_name = "duckdb"
return conn.engine.raw_connection().driver_connection.c
elif conn_type in ("AsyncEngine", "Engine"):
# note: if we create it, we can close it
self.can_close_cursor = True
return conn.connect()
else:
return conn
@@ -375,7 +377,7 @@ def _normalise_cursor(self, conn: Any) -> Cursor:
# can execute directly (given cursor, sqlalchemy connection, etc)
return conn

msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method"
msg = f"""Unrecognised connection type "{conn!r}"; no 'execute' or 'cursor' method"""
raise TypeError(msg)

async def _sqlalchemy_async_execute(self, query: TextClause, **options: Any) -> Any:
py-polars/polars/io/database/{_uri.py → _utils.py}
@@ -8,6 +8,8 @@
from polars.convert import from_arrow

if TYPE_CHECKING:
from collections.abc import Coroutine

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
@@ -22,6 +24,30 @@
Selectable: TypeAlias = Any # type: ignore[no-redef]


def _run_async(co: Coroutine[Any, Any, Any]) -> Any:
"""Run asynchronous code as if it was synchronous."""
import asyncio

from polars._utils.unstable import issue_unstable_warning

issue_unstable_warning(
"Use of asynchronous connections is currently considered unstable"
" and unexpected issues may arise; if this happens, please report them."
)
try:
import nest_asyncio

nest_asyncio.apply()
except ModuleNotFoundError as _err:
msg = (
"Executing using async drivers requires the `nest_asyncio` package."
"\n\nPlease run: pip install nest_asyncio"
)
raise ModuleNotFoundError(msg) from None

return asyncio.run(co)
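
In effect, `_run_async` lets the synchronous `read_database` path drive an
async driver to completion; `nest_asyncio.apply()` patches the event loop so
this also works where a loop is already running (e.g. inside Jupyter). A
minimal sketch of the same wrapper, with a stand-in coroutine:

import asyncio

async def fetch_rows() -> list[dict]:
    await asyncio.sleep(0)  # stand-in for an async driver call
    return [{"id": 1}]

rows = asyncio.run(fetch_rows())  # run the coroutine as if it were synchronous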


def _read_sql_connectorx(
query: str | list[str],
connection_uri: str,
53 changes: 41 additions & 12 deletions py-polars/polars/io/database/functions.py
@@ -76,9 +76,14 @@ def read_database( # noqa: D417
connection
An instantiated connection (or cursor/client object) that the query can be
executed against. Can also pass a valid ODBC connection string, identified as
such if it contains the string "Driver=", in which case the `arrow-odbc`
such if it contains the string "Driver={...}", in which case the `arrow-odbc`
package will be used to establish the connection and return Arrow-native data
to Polars.
to Polars. Async driver connections are also supported, though this is currently
considered unstable.

.. warning::
    Use of asynchronous connections is currently considered **unstable**, and
    unexpected issues may arise; if this happens, please report them.

iter_batches
Return an iterator of DataFrames, where each DataFrame represents a batch of
data returned by the query; this can be useful for processing large resultsets
@@ -126,15 +131,14 @@
include Dremio and InfluxDB).
* The `read_database_uri` function can be noticeably faster than `read_database`
if you are using a SQLAlchemy or DBAPI2 connection, as `connectorx` optimises
translation of the result set into Arrow format in Rust, whereas these libraries
will return row-wise data to Python *before* we can load into Arrow. Note that
you can determine the connection's URI from a SQLAlchemy engine object by calling
if you are using a SQLAlchemy or DBAPI2 connection, as `connectorx` and `adbc`
optimise translation of the result set into Arrow format. Note that you can
determine a connection's URI from a SQLAlchemy engine object by calling
`conn.engine.url.render_as_string(hide_password=False)`.
* If polars has to create a cursor from your connection in order to execute the
* If Polars has to create a cursor from your connection in order to execute the
query then that cursor will be automatically closed when the query completes;
however, polars will *never* close any other open connection or cursor.
however, Polars will *never* close any other open connection or cursor.
* We are able to support more than just relational databases and SQL queries
through this function. For example, we can load graph database results from
@@ -171,9 +175,9 @@
... execute_options={"parameters": [0]},
... ) # doctest: +SKIP
Instantiate a DataFrame using an ODBC connection string (requires `arrow-odbc`)
setting upper limits on the buffer size of variadic text/binary columns, returning
the result as an iterator over DataFrames that each contain 1000 rows:
Instantiate a DataFrame using an ODBC connection string (requires the `arrow-odbc`
package) setting upper limits on the buffer size of variadic text/binary columns,
returning the result as an iterator over DataFrames that each contain 1000 rows:
>>> for df in pl.read_database(
... query="SELECT * FROM test_data",
Expand All @@ -191,6 +195,31 @@ def read_database( # noqa: D417
... connection=kuzu_db_conn,
... ) # doctest: +SKIP
Load data from an asynchronous SQLAlchemy driver/engine; note that asynchronous
connections and sessions are also supported here:
>>> from sqlalchemy.ext.asyncio import create_async_engine
>>> async_engine = create_async_engine("sqlite+aiosqlite:///test.db")
>>> df = pl.read_database(
... query="SELECT * FROM test_data",
... connection=async_engine,
... ) # doctest: +SKIP
Load data from an asynchronous SurrealDB client connection object; note that
both the WS (`Surreal`) and HTTP (`SurrealHTTP`) clients are supported:
>>> import asyncio
>>> from surrealdb import Surreal
>>> async def surreal_query_to_frame(query: str, url: str):
... async with Surreal(url) as client:
... await client.use(namespace="test", database="test")
... return pl.read_database(query=query, connection=client)
>>> df = asyncio.run(
... surreal_query_to_frame(
... query="SELECT * FROM test_data",
... url="ws://localhost:8000/rpc",
... )
... ) # doctest: +SKIP
""" # noqa: W505
if isinstance(connection, str):
# check for odbc connection string
@@ -364,7 +393,7 @@ def read_database_uri(
... engine="adbc",
... ) # doctest: +SKIP
"""
from polars.io.database._uri import _read_sql_adbc, _read_sql_connectorx
from polars.io.database._utils import _read_sql_adbc, _read_sql_connectorx

if not isinstance(uri, str):
msg = f"expected connection to be a URI string; found {type(uri).__name__!r}"
