Skip to content

Commit

Permalink
feat(rust,python,cli): add SQL engine support for POSITION and `STR…
Browse files Browse the repository at this point in the history
…POS` (#13585)
  • Loading branch information
alexander-beedie committed Jan 10, 2024
1 parent fef0273 commit f8762fb
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
11 changes: 11 additions & 0 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ pub(crate) enum PolarsSQLFunctions {
/// SELECT column_2 from df WHERE STARTS_WITH(column_1, 'a');
/// ```
StartsWith,
/// SQL 'strpos' function
/// Returns the index of the given substring in the target string.
/// ```sql
/// SELECT STRPOS(column_1,'xyz') from df;
/// ```
StrPos,
/// SQL 'substr' function
/// Returns a portion of the data (first character = 0) in the range.
/// \[start, start + length]
Expand Down Expand Up @@ -673,6 +679,7 @@ impl PolarsSQLFunctions {
"lower" => Self::Lower,
"ltrim" => Self::LTrim,
"octet_length" => Self::OctetLength,
"strpos" => Self::StrPos,
"regexp_like" => Self::RegexpLike,
"replace" => Self::Replace,
"reverse" => Self::Reverse,
Expand Down Expand Up @@ -844,6 +851,10 @@ impl SQLFunctionVisitor<'_> {
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for LTrim: {}", function.args.len()),
},
OctetLength => self.visit_unary(|e| e.str().len_bytes()),
StrPos => {
// note: 1-indexed, not 0-indexed, and returns zero if match not found
self.visit_binary(|expr, substring| (expr.str().find(substring, true) + lit(1u32)).fill_null(0u32))
},
RegexpLike => match function.args.len() {
2 => self.visit_binary(|e, s| e.str().contains(s, true)),
3 => self.try_visit_ternary(|e, pat, flags| {
Expand Down
12 changes: 10 additions & 2 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@ impl SQLExprVisitor<'_> {
escape_char,
} => self.visit_like(*negated, expr, pattern, escape_char, true),
SQLExpr::Nested(expr) => self.visit_expr(expr),
SQLExpr::Position { expr, r#in } => Ok(
// note: SQL is 1-indexed, not 0-indexed
(self
.visit_expr(r#in)?
.str()
.find(self.visit_expr(expr)?, true)
+ lit(1u32))
.fill_null(0u32),
),
SQLExpr::RLike {
// note: parses both RLIKE and REGEXP
negated,
Expand Down Expand Up @@ -178,8 +187,8 @@ impl SQLExprVisitor<'_> {
}

let mut lf = self.ctx.execute_query_no_ctes(subquery)?;

let schema = lf.schema()?;

if restriction == SubqueryRestriction::SingleColumn {
if schema.len() != 1 {
polars_bail!(InvalidOperation: "SQL subquery will return more than one column");
Expand All @@ -194,7 +203,6 @@ impl SQLExprVisitor<'_> {
if let Some((old_name, _)) = schema_entry {
let new_name = String::from(old_name.as_str()) + rand_string.as_str();
lf = lf.rename([old_name.to_string()], [new_name.clone()]);

return Ok(Expr::SubPlan(
SpecialEq::new(Arc::new(lf.logical_plan)),
vec![new_name],
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.extract
Expr.str.extract_all
Expr.str.extract_groups
Expr.str.find
Expr.str.json_decode
Expr.str.json_extract
Expr.str.json_path_match
Expand Down
57 changes: 57 additions & 0 deletions py-polars/tests/unit/sql/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import polars as pl
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal


# TODO: Do not rely on I/O for these tests
Expand Down Expand Up @@ -174,6 +175,62 @@ def test_string_like(pattern: str, like: str, expected: list[int]) -> None:
assert res == expected


def test_string_position() -> None:
df = pl.Series(
name="city",
values=["Dubai", "Abu Dhabi", "Sharjah", "Al Ain", "Ajman", "Ras Al Khaimah"],
).to_frame()

with pl.SQLContext(cities=df, eager_execution=True) as ctx:
res = ctx.execute(
"""
SELECT
POSITION('a' IN city) AS a_lc1,
POSITION('A' IN city) AS a_uc1,
STRPOS(city,'a') AS a_lc2,
STRPOS(city,'A') AS a_uc2,
FROM cities
"""
)
expected_lc = [4, 7, 3, 0, 4, 2]
expected_uc = [0, 1, 0, 1, 1, 5]

assert res.to_dict(as_series=False) == {
"a_lc1": expected_lc,
"a_uc1": expected_uc,
"a_lc2": expected_lc,
"a_uc2": expected_uc,
}

df = pl.DataFrame({"txt": ["AbCdEXz", "XyzFDkE"]})
with pl.SQLContext(txt=df) as ctx:
res = ctx.execute(
"""
SELECT
txt,
POSITION('E' IN txt) AS match_E,
STRPOS(txt,'X') AS match_X
FROM txt
""",
eager=True,
)
assert_frame_equal(
res,
pl.DataFrame(
data={
"txt": ["AbCdEXz", "XyzFDkE"],
"match_E": [5, 7],
"match_X": [6, 1],
},
schema={
"txt": pl.String,
"match_E": pl.UInt32,
"match_X": pl.UInt32,
},
),
)


def test_string_replace() -> None:
df = pl.DataFrame({"words": ["Yemeni coffee is the best coffee", "", None]})
with pl.SQLContext(df=df) as ctx:
Expand Down

0 comments on commit f8762fb

Please sign in to comment.