Skip to content

Commit

Permalink
expose connectorx 'protocol' param to "read_sql" so it can work with redshift (#2003)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Dec 6, 2021
1 parent 22e40df commit e4bd365
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 10 deletions.
23 changes: 13 additions & 10 deletions py-polars/polars/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,10 +735,11 @@ def read_sql(
partition_on: Optional[str] = None,
partition_range: Optional[Tuple[int, int]] = None,
partition_num: Optional[int] = None,
protocol: Optional[str] = None,
) -> DataFrame:
"""
Read a SQL query into a DataFrame
Make sure to install connextorx>=0.2
Read a SQL query into a DataFrame.
Make sure to install connectorx>=0.2
# Sources
Supports reading a sql query from the following data sources:
Expand All @@ -758,30 +759,31 @@ def read_sql(
Parameters
----------
sql
raw sql query
raw sql query.
connection_uri
connectorx connection uri:
- "postgresql://username:password@server:port/database"
partition_on
the column to partition the result.
the column on which to partition the result.
partition_range
the value range of the partition column.
partition_num
how many partition to generate.
how many partitions to generate.
protocol
backend-specific transfer protocol directive; see connectorx documentation for details.
Examples
--------
## Single threaded
Read a DataFrame from a SQL using a single thread:
Read a DataFrame from a SQL query using a single thread:
>>> uri = "postgresql://username:password@server:port/database"
>>> query = "SELECT * FROM lineitem"
>>> pl.read_sql(query, uri) # doctest: +SKIP
## Using 10 threads
Read a DataFrame parallelly using 10 threads by automatically partitioning the provided SQL on the partition column:
Read a DataFrame in parallel using 10 threads by automatically partitioning the provided SQL on the partition column:
>>> uri = "postgresql://username:password@server:port/database"
>>> query = "SELECT * FROM lineitem"
Expand All @@ -790,7 +792,7 @@ def read_sql(
... ) # doctest: +SKIP
## Using
Read a DataFrame parallel using 2 threads by manually providing two partition SQLs:
Read a DataFrame in parallel using 2 threads by explicitly providing two SQL queries:
>>> uri = "postgresql://username:password@server:port/database"
>>> queries = [
Expand All @@ -808,9 +810,10 @@ def read_sql(
partition_on=partition_on,
partition_range=partition_range,
partition_num=partition_num,
protocol=protocol,
)
return from_arrow(tbl) # type: ignore[return-value]
else:
raise ImportError(
"connectorx is not installed." "Please run pip install connectorx>=0.2.0a3"
"connectorx is not installed." "Please run pip install connectorx>=0.2.2"
)
53 changes: 53 additions & 0 deletions py-polars/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import copy
import gzip
import io
import os
import pickle
import zlib
from datetime import date
from functools import partial
from pathlib import Path
from typing import Dict, Type
Expand Down Expand Up @@ -406,3 +408,54 @@ def test_scan_csv() -> None:
def test_scan_parquet() -> None:
df = pl.scan_parquet(Path(__file__).parent / "files" / "small.parquet")
assert df.collect().shape == (4, 3)


def test_read_sql() -> None:
    """Round-trip a small SQLite table through ``pl.read_sql``.

    Creates a temporary SQLite database with one table, reads it back via
    connectorx, and checks schema, shape, and date decoding. Skipped
    silently when connectorx is not installed on the test machine.
    """
    import sqlite3
    import tempfile

    # Keep the try body minimal: only the import can raise the ImportError
    # we want to treat as "skip this test" — any other failure below should
    # surface as a real test error, not be swallowed.
    try:
        import connectorx  # noqa
    except ImportError:
        return  # connectorx not installed on test machine; skip

    with tempfile.TemporaryDirectory() as tmpdir_name:
        test_db = os.path.join(tmpdir_name, "test.db")
        conn = sqlite3.connect(test_db)
        conn.executescript(
            """
            CREATE TABLE test_data (
                id    INTEGER PRIMARY KEY,
                name  TEXT NOT NULL,
                value FLOAT,
                date  DATE
            );
            INSERT INTO test_data(name,value,date) VALUES ('misc',100.0,'2020-01-01'), ('other',-99.5,'2021-12-31');
            """
        )
        conn.close()

        df = pl.read_sql(
            connection_uri=f"sqlite:///{test_db}", sql="SELECT * FROM test_data"
        )
        # ┌─────┬───────┬───────┬────────────┐
        # │ id  ┆ name  ┆ value ┆ date       │
        # │ --- ┆ ---   ┆ ---   ┆ ---        │
        # │ i64 ┆ str   ┆ f64   ┆ date       │
        # ╞═════╪═══════╪═══════╪════════════╡
        # │ 1   ┆ misc  ┆ 100.0 ┆ 2020-01-01 │
        # ├╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
        # │ 2   ┆ other ┆ -99.5 ┆ 2021-12-31 │
        # └─────┴───────┴───────┴────────────┘

        expected = {
            "id": pl.Int64,
            "name": pl.Utf8,
            "value": pl.Float64,
            "date": pl.Date,
        }
        assert df.schema == expected
        assert df.shape == (2, 4)
        # SQLite DATE text must come back as real python dates, not strings.
        assert df["date"].to_list() == [date(2020, 1, 1), date(2021, 12, 31)]

0 comments on commit e4bd365

Please sign in to comment.