feat(rust,python,cli): add SQL engine support for RIGHT and `REVERSE` string functions (#13461)
alexander-beedie committed Jan 5, 2024
1 parent 6f77f94 commit e66297a
Showing 18 changed files with 1,704 additions and 1,555 deletions.
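For a quick sense of what lands here, a minimal sketch of the two new functions driven from Python (the frame, column name, and expected values below are illustrative assumptions, not taken from this diff):

```python
import polars as pl

df = pl.DataFrame({"s": ["polars", "sql"]})

# SQLContext with eager_execution=True returns DataFrames directly,
# mirroring how the new tests in this commit drive the SQL engine
with pl.SQLContext(df=df, eager_execution=True) as ctx:
    out = ctx.execute("SELECT REVERSE(s) AS rev, RIGHT(s, 3) AS r3 FROM df")

# expected (assuming this commit's semantics):
#   rev: ["sralop", "lqs"]
#   r3:  ["ars", "sql"]   -- RIGHT returns the whole string when it is shorter than n
```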
2 changes: 1 addition & 1 deletion crates/polars-sql/Cargo.toml
@@ -12,7 +12,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans"
arrow = { workspace = true }
polars-core = { workspace = true }
polars-error = { workspace = true }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "cross_join", "cum_agg", "dtype-date", "is_in", "log", "meta", "regex", "round_series", "sign", "strings", "trigonometry"] }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "cross_join", "cum_agg", "dtype-date", "is_in", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "trigonometry"] }
polars-plan = { workspace = true }

hex = { workspace = true }
27 changes: 25 additions & 2 deletions crates/polars-sql/src/functions.rs
@@ -268,7 +268,7 @@ pub(crate) enum PolarsSQLFunctions {
#[cfg(feature = "nightly")]
InitCap,
/// SQL 'left' function
/// Returns the `length` first characters
/// Returns the first (leftmost) `n` characters
/// ```sql
/// SELECT LEFT(column_1, 3) from df;
/// ```
@@ -309,6 +309,18 @@
/// SELECT REPLACE(column_1,'old','new') from df;
/// ```
Replace,
/// SQL 'reverse' function
/// Returns the reversed string.
/// ```sql
/// SELECT REVERSE(column_1) from df;
/// ```
Reverse,
/// SQL 'right' function
/// Returns the last (rightmost) `n` characters
/// ```sql
/// SELECT RIGHT(column_1, 3) from df;
/// ```
Right,
/// SQL 'rtrim' function
/// Strip whitespaces from the right
/// ```sql
@@ -637,6 +649,8 @@ impl PolarsSQLFunctions {
"octet_length" => Self::OctetLength,
"regexp_like" => Self::RegexpLike,
"replace" => Self::Replace,
"reverse" => Self::Reverse,
"right" => Self::Right,
"rtrim" => Self::RTrim,
"starts_with" => Self::StartsWith,
"substr" => Self::Substring,
@@ -809,7 +823,16 @@ impl SQLFunctionVisitor<'_> {
"Invalid number of arguments for Replace: {}",
function.args.len()
),
}
},
Reverse => self.visit_unary(|e| e.str().reverse()),
Right => self.try_visit_binary(|e, length| {
Ok(e.str().slice(match length {
Expr::Literal(LiteralValue::Int64(n)) => -n,
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Right: {}", function.args[1]);
}
}, None))
}),
RTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))),
2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
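One note on the Right arm above: RIGHT(col, n) is translated into a string slice with a negative offset and no explicit length, and only an Int64 literal is accepted for n (any other expression hits the polars_bail! path). A rough Python equivalent of that translation, with a hypothetical column name:

```python
import polars as pl

df = pl.DataFrame({"s": ["polars", "sql"]})

# RIGHT(s, 3) lowers to s.str.slice(-3): start 3 characters from the
# end and take the rest; strings shorter than 3 come back unchanged
assert df.select(pl.col("s").str.slice(-3)).to_series().to_list() == ["ars", "sql"]
```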
67 changes: 67 additions & 0 deletions py-polars/tests/unit/sql/test_cast.py
@@ -0,0 +1,67 @@
from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import ComputeError


def test_cast() -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3, 4, 5],
"b": [1.1, 2.2, 3.3, 4.4, 5.5],
"c": ["a", "b", "c", "d", "e"],
"d": [True, False, True, False, True],
}
)
# test various dtype casts, using standard ("CAST <col> AS <dtype>")
# and postgres-specific ("<col>::<dtype>") cast syntax
with pl.SQLContext(df=df, eager_execution=True) as ctx:
res = ctx.execute(
"""
SELECT
-- float
CAST(a AS DOUBLE PRECISION) AS a_f64,
a::real AS a_f32,
-- integer
CAST(b AS TINYINT) AS b_i8,
CAST(b AS SMALLINT) AS b_i16,
b::bigint AS b_i64,
d::tinyint AS d_i8,
-- string/binary
CAST(a AS CHAR) AS a_char,
CAST(b AS VARCHAR) AS b_varchar,
c::blob AS c_blob,
c::bytes AS c_bytes,
c::VARBINARY AS c_varbinary,
CAST(d AS CHARACTER VARYING) AS d_charvar,
FROM df
"""
)
assert res.schema == {
"a_f64": pl.Float64,
"a_f32": pl.Float32,
"b_i8": pl.Int8,
"b_i16": pl.Int16,
"b_i64": pl.Int64,
"d_i8": pl.Int8,
"a_char": pl.String,
"b_varchar": pl.String,
"c_blob": pl.Binary,
"c_bytes": pl.Binary,
"c_varbinary": pl.Binary,
"d_charvar": pl.String,
}
assert res.rows() == [
(1.0, 1.0, 1, 1, 1, 1, "1", "1.1", b"a", b"a", b"a", "true"),
(2.0, 2.0, 2, 2, 2, 0, "2", "2.2", b"b", b"b", b"b", "false"),
(3.0, 3.0, 3, 3, 3, 1, "3", "3.3", b"c", b"c", b"c", "true"),
(4.0, 4.0, 4, 4, 4, 0, "4", "4.4", b"d", b"d", b"d", "false"),
(5.0, 5.0, 5, 5, 5, 1, "5", "5.5", b"e", b"e", b"e", "true"),
]

with pytest.raises(ComputeError, match="unsupported use of FORMAT in CAST"):
pl.SQLContext(df=df, eager_execution=True).execute(
"SELECT CAST(a AS STRING FORMAT 'HEX') FROM df"
)
70 changes: 70 additions & 0 deletions py-polars/tests/unit/sql/test_conditional.py
@@ -0,0 +1,70 @@
from __future__ import annotations

from pathlib import Path

import pytest

import polars as pl
from polars.exceptions import InvalidOperationError


@pytest.fixture()
def foods_ipc_path() -> Path:
return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"


def test_case_when() -> None:
lf = pl.LazyFrame(
{
"v1": [None, 2, None, 4],
"v2": [101, 202, 303, 404],
}
)
with pl.SQLContext(test_data=lf, eager_execution=True) as ctx:
out = ctx.execute(
"""
SELECT *, CASE WHEN COALESCE(v1, v2) % 2 != 0 THEN 'odd' ELSE 'even' END as "v3"
FROM test_data
"""
)
assert out.to_dict(as_series=False) == {
"v1": [None, 2, None, 4],
"v2": [101, 202, 303, 404],
"v3": ["odd", "even", "odd", "even"],
}


def test_nullif_coalesce(foods_ipc_path: Path) -> None:
nums = pl.LazyFrame(
{
"x": [1, None, 2, 3, None, 4],
"y": [5, 4, None, 3, None, 2],
"z": [3, 4, None, 3, 6, None],
}
)
res = pl.SQLContext(df=nums).execute(
"""
SELECT
COALESCE(x,y,z) as "coalsc",
NULLIF(x, y) as "nullif x_y",
NULLIF(y, z) as "nullif y_z",
IFNULL(x, y) as "ifnull x_y",
IFNULL(y,-1) as "ifnull y_z",
COALESCE(x, NULLIF(y,z)) as "both"
FROM df
""",
eager=True,
)

assert res.to_dict(as_series=False) == {
"coalsc": [1, 4, 2, 3, 6, 4],
"nullif x_y": [1, None, 2, None, None, 4],
"nullif y_z": [5, None, None, None, None, 2],
"ifnull x_y": [1, 4, 2, 3, None, 4],
"inullf y_z": [5, 4, -1, 3, -1, 2],
"both": [1, None, 2, 3, None, 4],
}
for null_func in ("IFNULL", "NULLIF"):
# both functions expect only 2 arguments
with pytest.raises(InvalidOperationError):
pl.SQLContext(df=nums).execute(f"SELECT {null_func}(x,y,z) FROM df")
37 changes: 37 additions & 0 deletions py-polars/tests/unit/sql/test_functions.py
@@ -0,0 +1,37 @@
from __future__ import annotations

from pathlib import Path

import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal


@pytest.fixture()
def foods_ipc_path() -> Path:
return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"


def test_sql_expr() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": ["xyz", "abcde", None]})
sql_exprs = pl.sql_expr(
[
"MIN(a)",
"POWER(a,a) AS aa",
"SUBSTR(b,2,2) AS b2",
]
)
result = df.select(*sql_exprs)
expected = pl.DataFrame(
{"a": [1, 1, 1], "aa": [1.0, 4.0, 27.0], "b2": ["yz", "bc", None]}
)
assert_frame_equal(result, expected)

# expect expressions that can't reasonably be parsed as expressions to raise
# (for example: those that explicitly reference tables and/or use wildcards)
with pytest.raises(
InvalidOperationError, match=r"Unable to parse 'xyz\.\*' as Expr"
):
pl.sql_expr("xyz.*")
62 changes: 62 additions & 0 deletions py-polars/tests/unit/sql/test_group_by.py
@@ -0,0 +1,62 @@
from __future__ import annotations

from pathlib import Path

import pytest

import polars as pl


@pytest.fixture()
def foods_ipc_path() -> Path:
return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"


def test_group_by(foods_ipc_path: Path) -> None:
lf = pl.scan_ipc(foods_ipc_path)

ctx = pl.SQLContext(eager_execution=True)
ctx.register("foods", lf)

out = ctx.execute(
"""
SELECT
category,
count(category) as n,
max(calories),
min(fats_g)
FROM foods
GROUP BY category
HAVING n > 5
ORDER BY n, category DESC
"""
)
assert out.to_dict(as_series=False) == {
"category": ["vegetables", "fruit", "seafood"],
"n": [7, 7, 8],
"calories": [45, 130, 200],
"fats_g": [0.0, 0.0, 1.5],
}

lf = pl.LazyFrame(
{
"grp": ["a", "b", "c", "c", "b"],
"att": ["x", "y", "x", "y", "y"],
}
)
assert ctx.tables() == ["foods"]

ctx.register("test", lf)
assert ctx.tables() == ["foods", "test"]

out = ctx.execute(
"""
SELECT
grp,
COUNT(DISTINCT att) AS n_dist_attr
FROM test
GROUP BY grp
HAVING n_dist_attr > 1
"""
)
assert out.to_dict(as_series=False) == {"grp": ["c"], "n_dist_attr": [2]}
