Skip to content

Commit

Permalink
Merge pull request #613 from zen-xu/connector
Browse files Browse the repository at this point in the history
feat: Support structured construction of database connections.
  • Loading branch information
wangxiaoying authored Apr 22, 2024
2 parents bba7a4e + 349d742 commit a9e73f2
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 35 deletions.
161 changes: 147 additions & 14 deletions connectorx-python/connectorx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import importlib
from importlib.metadata import version
import urllib.parse

from typing import Literal, TYPE_CHECKING, overload
from importlib.metadata import version
from pathlib import Path
from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar

from .connectorx import (
read_sql as _read_sql,
Expand All @@ -20,7 +22,7 @@
import pyarrow as pa

# only for typing hints
from .connectorx import _DataframeInfos, _ArrowInfos
from .connectorx import _DataframeInfos, _ArrowInfos


__version__ = version(__name__)
Expand All @@ -42,7 +44,12 @@
Protocol = Literal["csv", "binary", "cursor", "simple", "text"]


def rewrite_conn(conn: str, protocol: Protocol | None = None) -> tuple[str, Protocol]:
_BackendT = TypeVar("_BackendT")


def rewrite_conn(
conn: str | ConnectionUrl, protocol: Protocol | None = None
) -> tuple[str, Protocol]:
if not protocol:
# note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database
# drivers to connect. set a compatible protocol and masquerade as the appropriate backend.
Expand All @@ -59,7 +66,7 @@ def rewrite_conn(conn: str, protocol: Protocol | None = None) -> tuple[str, Prot


def get_meta(
conn: str,
conn: str | ConnectionUrl,
query: str,
protocol: Protocol | None = None,
) -> pd.DataFrame:
Expand All @@ -84,7 +91,7 @@ def get_meta(


def partition_sql(
conn: str,
conn: str | ConnectionUrl,
query: str,
partition_on: str,
partition_num: int,
Expand Down Expand Up @@ -118,7 +125,7 @@ def partition_sql(

def read_sql_pandas(
sql: list[str] | str,
con: str | dict[str, str],
con: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
index_col: str | None = None,
protocol: Protocol | None = None,
partition_on: str | None = None,
Expand Down Expand Up @@ -159,7 +166,7 @@ def read_sql_pandas(
# default return pd.DataFrame
@overload
def read_sql(
conn: str | dict[str, str],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
protocol: Protocol | None = None,
Expand All @@ -172,7 +179,7 @@ def read_sql(

@overload
def read_sql(
conn: str | dict[str, str],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["pandas"],
Expand All @@ -186,7 +193,7 @@ def read_sql(

@overload
def read_sql(
conn: str | dict[str, str],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["arrow", "arrow2"],
Expand All @@ -200,7 +207,7 @@ def read_sql(

@overload
def read_sql(
conn: str | dict[str, str],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["modin"],
Expand All @@ -214,7 +221,7 @@ def read_sql(

@overload
def read_sql(
conn: str | dict[str, str],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["dask"],
Expand All @@ -228,7 +235,7 @@ def read_sql(

@overload
def read_sql(
conn: str | dict[str, str],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["polars", "polars2"],
Expand All @@ -241,7 +248,7 @@ def read_sql(


def read_sql(
conn: str | dict[str, str],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal[
Expand Down Expand Up @@ -477,3 +484,129 @@ def try_import_module(name: str):
return importlib.import_module(name)
except ModuleNotFoundError:
raise ValueError(f"You need to install {name.split('.')[0]} first")


# Type variable enumerating the "server-style" backends: databases whose
# connection URL is assembled from username/password/host/port/database
# (as opposed to the path-based sqlite/bigquery forms). It lets the
# ConnectionUrl overloads return a precisely-parameterized type, e.g.
# ConnectionUrl[Literal["postgres"]].
_ServerBackendT = TypeVar(
    "_ServerBackendT",
    bound=Literal[
        "redshift",
        "clickhouse",
        "postgres",
        "postgresql",
        "mysql",
        "mssql",
        "oracle",
        "duckdb",
    ],
)


class ConnectionUrl(Generic[_BackendT], str):
    """A typed builder for database connection string URLs.

    Instances are plain ``str`` subclasses, so a ``ConnectionUrl`` can be
    passed anywhere a raw connection string is accepted (``read_sql``,
    ``get_meta``, ``partition_sql``, ...). The generic parameter records
    which backend the URL targets, purely for static typing.
    """

    @overload
    def __new__(
        cls,
        *,
        backend: Literal["sqlite"],
        db_path: str | Path,
    ) -> ConnectionUrl[Literal["sqlite"]]:
        """
        Help to build sqlite connection string url.

        Parameters
        ==========
        backend:
          must specify "sqlite".
        db_path:
          the path to the sqlite database file.
        """

    @overload
    def __new__(
        cls,
        *,
        backend: Literal["bigquery"],
        db_path: str | Path,
    ) -> ConnectionUrl[Literal["bigquery"]]:
        """
        Help to build BigQuery connection string url.

        Parameters
        ==========
        backend:
          must specify "bigquery".
        db_path:
          the path to the bigquery credential (authentication) file.
        """

    @overload
    def __new__(
        cls,
        *,
        backend: _ServerBackendT,
        username: str,
        password: str = "",
        server: str,
        port: int,
        database: str = "",
        database_options: dict[str, str] | None = None,
    ) -> ConnectionUrl[_ServerBackendT]:
        """
        Help to build server-side backend database connection string url.

        Parameters
        ==========
        backend:
          the database backend.
        username:
          the database username.
        password:
          the database password.
        server:
          the database server name.
        port:
          the database server port.
        database:
          the database name.
        database_options:
          the database options for connection.
        """

    @overload
    def __new__(
        cls,
        raw_connection: str,
    ) -> ConnectionUrl:
        """
        Build connection from raw connection string url.

        Parameters
        ==========
        raw_connection:
          raw connection string
        """

    def __new__(
        cls,
        raw_connection: str | None = None,
        *,
        backend: str = "",
        username: str = "",
        password: str = "",
        server: str = "",
        port: int | None = None,
        database: str = "",
        database_options: dict[str, str] | None = None,
        db_path: str | Path = "",
    ) -> ConnectionUrl:
        # A raw connection string is passed through verbatim: the caller is
        # responsible for its escaping, and wrapping it must not alter it.
        if raw_connection is not None:
            return super().__new__(cls, raw_connection)

        # Input validation must survive `python -O`, so raise instead of assert.
        if not backend:
            raise ValueError(
                "either `raw_connection` or `backend` must be specified"
            )

        if backend in ("sqlite", "bigquery"):
            # Path-based backends: the URL body is the percent-encoded path to
            # the database file (sqlite) or credential file (bigquery). The
            # bigquery overload takes `db_path`, so it must be handled here
            # rather than falling through to the server-style format.
            connection = f"{backend}://{urllib.parse.quote(str(db_path))}"
        else:
            # Percent-encode credentials so characters such as '@', ':' or '/'
            # in the username/password cannot corrupt the URL structure.
            quoted_username = urllib.parse.quote(username, safe="")
            quoted_password = urllib.parse.quote(password, safe="")
            connection = (
                f"{backend}://{quoted_username}:{quoted_password}"
                f"@{server}:{port}/{database}"
            )
            if database_options:
                connection += "?" + urllib.parse.urlencode(database_options)
        return super().__new__(cls, connection)
28 changes: 15 additions & 13 deletions connectorx-python/connectorx/tests/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pandas.testing import assert_frame_equal

from .. import read_sql
from .. import read_sql, ConnectionUrl


@pytest.fixture(scope="module") # type: ignore
Expand Down Expand Up @@ -121,9 +121,7 @@ def test_bigquery_some_empty_partition(bigquery_url: str) -> None:
index=range(1),
data={
"test_int": pd.Series([1], dtype="Int64"),
"test_string": pd.Series(
["str1"], dtype="object"
),
"test_string": pd.Series(["str1"], dtype="object"),
"test_float": pd.Series([1.10], dtype="float64"),
"test_bool": pd.Series([True], dtype="boolean"),
},
Expand All @@ -137,10 +135,7 @@ def test_bigquery_some_empty_partition(bigquery_url: str) -> None:
)
def test_bigquery_join(bigquery_url: str) -> None:
query = "SELECT T.test_int, T.test_string, S.test_str FROM `dataprep-bigquery.dataprep.test_table` T INNER JOIN `dataprep-bigquery.dataprep.test_types` S ON T.test_int = S.test_int"
df = read_sql(
bigquery_url,
query
)
df = read_sql(bigquery_url, query)
df = df.sort_values("test_int").reset_index(drop=True)
expected = pd.DataFrame(
index=range(2),
Expand All @@ -151,14 +146,14 @@ def test_bigquery_join(bigquery_url: str) -> None:
"str1",
"str2",
],
dtype="object"
dtype="object",
),
"test_str": pd.Series(
[
"😁😂😜",
"こんにちはЗдра́в",
],
dtype="object"
dtype="object",
),
},
)
Expand Down Expand Up @@ -188,22 +183,21 @@ def test_bigquery_join_with_partition(bigquery_url: str) -> None:
"str1",
"str2",
],
dtype="object"
dtype="object",
),
"test_str": pd.Series(
[
"😁😂😜",
"こんにちはЗдра́в",
],
dtype="object"
dtype="object",
),
},
)
df.sort_values(by="test_int", inplace=True, ignore_index=True)
assert_frame_equal(df, expected, check_names=True)



@pytest.mark.skipif(
not os.environ.get("BIGQUERY_URL"),
reason="Test bigquery only when `BIGQUERY_URL` is set",
Expand Down Expand Up @@ -310,3 +304,11 @@ def test_bigquery_types(bigquery_url: str) -> None:
},
)
assert_frame_equal(df, expected, check_names=True)


@pytest.mark.skipif(
    not os.environ.get("BIGQUERY_URL"),
    reason="Test bigquery only when `BIGQUERY_URL` is set",
)
def test_connection_url(bigquery_url: str) -> None:
    """A raw URL wrapped in ConnectionUrl must behave exactly like the string."""
    wrapped = ConnectionUrl(bigquery_url)
    test_bigquery_types(wrapped)
10 changes: 9 additions & 1 deletion connectorx-python/connectorx/tests/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pandas.testing import assert_frame_equal

from .. import read_sql
from .. import read_sql, ConnectionUrl


@pytest.fixture(scope="module") # type: ignore
Expand Down Expand Up @@ -81,3 +81,11 @@ def test_clickhouse_types(clickhouse_url: str) -> None:
},
)
assert_frame_equal(df, expected, check_names=True)


@pytest.mark.skipif(
    not os.environ.get("CLICKHOUSE_URL"),
    reason="Do not test Clickhouse unless `CLICKHOUSE_URL` is set",
)
def test_connection_url(clickhouse_url: str) -> None:
    """A raw URL wrapped in ConnectionUrl must behave exactly like the string."""
    wrapped = ConnectionUrl(clickhouse_url)
    test_clickhouse_types(wrapped)
6 changes: 5 additions & 1 deletion connectorx-python/connectorx/tests/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal
from connectorx import ConnectionUrl

from .. import read_sql

Expand Down Expand Up @@ -92,7 +93,6 @@ def test_mssql_udf(mssql_url: str) -> None:


def test_manual_partition(mssql_url: str) -> None:

queries = [
"SELECT * FROM test_table WHERE test_int < 2",
"SELECT * FROM test_table WHERE test_int >= 2",
Expand Down Expand Up @@ -496,3 +496,7 @@ def test_mssql_offset(mssql_url: str) -> None:
}
)
assert_frame_equal(df, expected, check_names=True)


def test_connection_url(mssql_url: str) -> None:
    """A raw URL wrapped in ConnectionUrl must behave exactly like the string."""
    wrapped = ConnectionUrl(mssql_url)
    test_mssql_offset(wrapped)
Loading

0 comments on commit a9e73f2

Please sign in to comment.