Skip to content

Commit

Permalink
rename Connection -> ConnectionUrl
Browse files Browse the repository at this point in the history
Signed-off-by: ZhengYu, Xu <zen-xu@outlook.com>
  • Loading branch information
zen-xu committed Apr 21, 2024
1 parent 0783580 commit 1e71570
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 54 deletions.
42 changes: 21 additions & 21 deletions connectorx-python/connectorx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@


def rewrite_conn(
conn: str | Connection, protocol: Protocol | None = None
conn: str | ConnectionUrl, protocol: Protocol | None = None
) -> tuple[str, Protocol]:
if not protocol:
# note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database
Expand All @@ -63,7 +63,7 @@ def rewrite_conn(


def get_meta(
conn: str | Connection,
conn: str | ConnectionUrl,
query: str,
protocol: Protocol | None = None,
) -> pd.DataFrame:
Expand All @@ -88,7 +88,7 @@ def get_meta(


def partition_sql(
conn: str | Connection,
conn: str | ConnectionUrl,
query: str,
partition_on: str,
partition_num: int,
Expand Down Expand Up @@ -122,7 +122,7 @@ def partition_sql(

def read_sql_pandas(
sql: list[str] | str,
con: str | Connection | dict[str, str] | dict[str, Connection],
con: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
index_col: str | None = None,
protocol: Protocol | None = None,
partition_on: str | None = None,
Expand Down Expand Up @@ -163,7 +163,7 @@ def read_sql_pandas(
# default return pd.DataFrame
@overload
def read_sql(
conn: str | Connection | dict[str, str] | dict[str, Connection],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
protocol: Protocol | None = None,
Expand All @@ -176,7 +176,7 @@ def read_sql(

@overload
def read_sql(
conn: str | Connection | dict[str, str] | dict[str, Connection],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["pandas"],
Expand All @@ -190,7 +190,7 @@ def read_sql(

@overload
def read_sql(
conn: str | Connection | dict[str, str] | dict[str, Connection],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["arrow", "arrow2"],
Expand All @@ -204,7 +204,7 @@ def read_sql(

@overload
def read_sql(
conn: str | Connection | dict[str, str] | dict[str, Connection],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["modin"],
Expand All @@ -218,7 +218,7 @@ def read_sql(

@overload
def read_sql(
conn: str | Connection | dict[str, str] | dict[str, Connection],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["dask"],
Expand All @@ -232,7 +232,7 @@ def read_sql(

@overload
def read_sql(
conn: str | Connection | dict[str, str] | dict[str, Connection],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal["polars", "polars2"],
Expand All @@ -245,7 +245,7 @@ def read_sql(


def read_sql(
conn: str | Connection | dict[str, str] | dict[str, Connection],
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
*,
return_type: Literal[
Expand Down Expand Up @@ -500,16 +500,16 @@ def try_import_module(name: str):
)


class Connection(Generic[_BackendT], str):
class ConnectionUrl(Generic[_BackendT], str):
@overload
def __new__(
cls,
*,
backend: Literal["sqlite"],
db_path: str | Path,
) -> Connection[Literal["sqlite"]]:
) -> ConnectionUrl[Literal["sqlite"]]:
"""
Help to build sqlite connection string.
Help to build sqlite connection string url.
Parameters
==========
Expand All @@ -525,9 +525,9 @@ def __new__(
*,
backend: Literal["bigquery"],
db_path: str | Path,
) -> Connection[Literal["bigquery"]]:
) -> ConnectionUrl[Literal["bigquery"]]:
"""
Help to build BigQuery connection string.
Help to build BigQuery connection string url.
Parameters
==========
Expand All @@ -548,9 +548,9 @@ def __new__(
port: int,
database: str = "",
database_options: dict[str, str] | None = None,
) -> Connection[_ServerBackendT]:
) -> ConnectionUrl[_ServerBackendT]:
"""
Help to build server-side backend database connection string.
Help to build server-side backend database connection string url.
Parameters
==========
Expand All @@ -574,9 +574,9 @@ def __new__(
def __new__(
cls,
raw_connection: str,
) -> Connection:
) -> ConnectionUrl:
"""
Build connection from raw connection string
Build connection from raw connection string url
Parameters
==========
Expand All @@ -596,7 +596,7 @@ def __new__(
database: str = "",
database_options: dict[str, str] | None = None,
db_path: str | Path = "",
) -> Connection:
) -> ConnectionUrl:
if raw_connection is not None:
return super().__new__(cls, raw_connection)

Expand Down
6 changes: 3 additions & 3 deletions connectorx-python/connectorx/tests/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pandas.testing import assert_frame_equal

from .. import read_sql, Connection
from .. import read_sql, ConnectionUrl


@pytest.fixture(scope="module") # type: ignore
Expand Down Expand Up @@ -310,5 +310,5 @@ def test_bigquery_types(bigquery_url: str) -> None:
not os.environ.get("BIGQUERY_URL"),
reason="Test bigquery only when `BIGQUERY_URL` is set",
)
def test_connection(bigquery_url: str) -> None:
test_bigquery_types(Connection(bigquery_url))
def test_connection_url(bigquery_url: str) -> None:
test_bigquery_types(ConnectionUrl(bigquery_url))
10 changes: 6 additions & 4 deletions connectorx-python/connectorx/tests/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pandas.testing import assert_frame_equal

from .. import read_sql, Connection
from .. import read_sql, ConnectionUrl


@pytest.fixture(scope="module") # type: ignore
Expand Down Expand Up @@ -76,7 +76,9 @@ def test_clickhouse_types(clickhouse_url: str) -> None:
dtype="datetime64[ns]",
),
"test_decimal": pd.Series(["2.22", "3.33", "4.44"], dtype="object"),
"test_varchar": pd.Series(["こんにちは", "Ha好ち😁ðy", "b"], dtype="object"),
"test_varchar": pd.Series(
["こんにちは", "Ha好ち😁ðy", "b"], dtype="object"
),
"test_char": pd.Series(["0123456789", "abcdefghij", "321"], dtype="object"),
},
)
Expand All @@ -87,5 +89,5 @@ def test_clickhouse_types(clickhouse_url: str) -> None:
not os.environ.get("CLICKHOUSE_URL"),
reason="Do not test Clickhouse unless `CLICKHOUSE_URL` is set",
)
def test_connection(clickhouse_url: str) -> None:
test_clickhouse_types(Connection(clickhouse_url))
def test_connection_url(clickhouse_url: str) -> None:
test_clickhouse_types(ConnectionUrl(clickhouse_url))
4 changes: 2 additions & 2 deletions connectorx-python/connectorx/tests/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,5 +498,5 @@ def test_mssql_offset(mssql_url: str) -> None:
assert_frame_equal(df, expected, check_names=True)


def test_connection(mssql_url: str) -> None:
test_mssql_offset(Connection(mssql_url))
def test_connection_url(mssql_url: str) -> None:
test_mssql_offset(Connection(mssql_url))
6 changes: 3 additions & 3 deletions connectorx-python/connectorx/tests/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pandas.testing import assert_frame_equal

from .. import read_sql, Connection
from .. import read_sql, ConnectionUrl


@pytest.fixture(scope="module") # type: ignore
Expand Down Expand Up @@ -470,5 +470,5 @@ def test_mysql_cte(mysql_url: str) -> None:
assert_frame_equal(df, expected, check_names=True)


def test_connection(mysql_url: str) -> None:
test_mysql_cte(Connection(mysql_url))
def test_connection_url(mysql_url: str) -> None:
test_mysql_cte(ConnectionUrl(mysql_url))
15 changes: 8 additions & 7 deletions connectorx-python/connectorx/tests/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import pytest
from pandas.testing import assert_frame_equal

from .. import read_sql, Connection
from .. import read_sql, ConnectionUrl


@pytest.fixture(scope="module") # type: ignore
def oracle_url() -> str:
conn = os.environ["ORACLE_URL"]
return conn


@pytest.mark.xfail
@pytest.mark.skipif(
not os.environ.get("ORACLE_URL"), reason="Test oracle only when `ORACLE_URL` is set"
Expand Down Expand Up @@ -163,6 +164,7 @@ def test_oracle_manual_partition(oracle_url: str) -> None:
df.sort_values(by="TEST_INT", inplace=True, ignore_index=True)
assert_frame_equal(df, expected, check_names=True)


@pytest.mark.skipif(
not os.environ.get("ORACLE_URL"), reason="Test oracle only when `ORACLE_URL` is set"
)
Expand Down Expand Up @@ -353,11 +355,9 @@ def test_oracle_types(oracle_url: str) -> None:
],
dtype="datetime64[ns]",
),
"TEST_CLOB": pd.Series(
["13ab", "13ab", "13ab", None], dtype="object"
),
"TEST_CLOB": pd.Series(["13ab", "13ab", "13ab", None], dtype="object"),
"TEST_BLOB": pd.Series(
[ b'9\xaf', b'9\xaf', b'9\xaf', None], dtype="object"
[b"9\xaf", b"9\xaf", b"9\xaf", None], dtype="object"
),
}
)
Expand Down Expand Up @@ -429,6 +429,7 @@ def test_oracle_cte(oracle_url: str) -> None:
)
assert_frame_equal(df, expected, check_names=True)


@pytest.mark.skipif(
not os.environ.get("ORACLE_URL"), reason="Test oracle only when `ORACLE_URL` is set"
)
Expand All @@ -446,5 +447,5 @@ def test_oracle_round_function(oracle_url: str) -> None:
@pytest.mark.skipif(
not os.environ.get("ORACLE_URL"), reason="Test oracle only when `ORACLE_URL` is set"
)
def test_connection(oracle_url: str) -> None:
test_oracle_round_function(Connection(oracle_url))
def test_connection_url(oracle_url: str) -> None:
test_oracle_round_function(ConnectionUrl(oracle_url))
32 changes: 22 additions & 10 deletions connectorx-python/connectorx/tests/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from pandas.testing import assert_frame_equal

from .. import read_sql, Connection
from .. import read_sql, ConnectionUrl


@pytest.fixture(scope="module") # type: ignore
Expand All @@ -14,7 +14,10 @@ def redshift_url() -> str:
return conn


@pytest.mark.skipif(not os.environ.get("REDSHIFT_URL"), reason="Do not test Redshift unless `REDSHIFT_URL` is set")
@pytest.mark.skipif(
not os.environ.get("REDSHIFT_URL"),
reason="Do not test Redshift unless `REDSHIFT_URL` is set",
)
def test_redshift_without_partition(redshift_url: str) -> None:
query = "SELECT * FROM test_table"
df = read_sql(redshift_url, query, protocol="cursor")
Expand All @@ -37,7 +40,10 @@ def test_redshift_without_partition(redshift_url: str) -> None:
assert_frame_equal(df, expected, check_names=True)


@pytest.mark.skipif(not os.environ.get("REDSHIFT_URL"), reason="Do not test Redshift unless `REDSHIFT_URL` is set")
@pytest.mark.skipif(
not os.environ.get("REDSHIFT_URL"),
reason="Do not test Redshift unless `REDSHIFT_URL` is set",
)
def test_redshift_with_partition(redshift_url: str) -> None:
query = "SELECT * FROM test_table"
df = read_sql(
Expand All @@ -46,7 +52,7 @@ def test_redshift_with_partition(redshift_url: str) -> None:
partition_on="test_int",
partition_range=(0, 2000),
partition_num=3,
protocol="cursor"
protocol="cursor",
)
# result from redshift might have different order each time
df.sort_values(by="test_int", inplace=True, ignore_index=True)
Expand All @@ -67,7 +73,10 @@ def test_redshift_with_partition(redshift_url: str) -> None:
assert_frame_equal(df, expected, check_names=True)


@pytest.mark.skipif(not os.environ.get("REDSHIFT_URL"), reason="Do not test Redshift unless `REDSHIFT_URL` is set")
@pytest.mark.skipif(
not os.environ.get("REDSHIFT_URL"),
reason="Do not test Redshift unless `REDSHIFT_URL` is set",
)
def test_redshift_types(redshift_url: str) -> None:
query = "SELECT test_int16, test_char, test_time, test_datetime FROM test_types"
df = read_sql(redshift_url, query, protocol="cursor")
Expand All @@ -87,15 +96,18 @@ def test_redshift_types(redshift_url: str) -> None:
np.datetime64("2005-01-01T22:03:00"),
None,
np.datetime64("1987-01-01T11:00:00"),
], dtype="datetime64[ns]"
],
dtype="datetime64[ns]",
),

},
)
assert_frame_equal(df, expected, check_names=True)


@pytest.mark.skipif(not os.environ.get("REDSHIFT_URL"), reason="Do not test Redshift unless `REDSHIFT_URL` is set")
@pytest.mark.skipif(
not os.environ.get("REDSHIFT_URL"),
reason="Do not test Redshift unless `REDSHIFT_URL` is set",
)
def test_read_sql_on_utf8(redshift_url: str) -> None:
query = "SELECT * FROM test_str"
df = read_sql(redshift_url, query, protocol="cursor")
Expand Down Expand Up @@ -140,5 +152,5 @@ def test_read_sql_on_utf8(redshift_url: str) -> None:
not os.environ.get("REDSHIFT_URL"),
reason="Do not test Redshift unless `REDSHIFT_URL` is set",
)
def test_connection(redshift_url: str) -> None:
test_read_sql_on_utf8(Connection(redshift_url))
def test_connection_url(redshift_url: str) -> None:
test_read_sql_on_utf8(ConnectionUrl(redshift_url))
Loading

0 comments on commit 1e71570

Please sign in to comment.