diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 1834ae4f608d..f6bbd4748f29 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -350,6 +350,12 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT column_2 from df WHERE STARTS_WITH(column_1, 'a'); /// ``` StartsWith, + /// SQL 'strpos' function + /// Returns the 1-based index of the given substring in the target string (0 if not found). + /// ```sql + /// SELECT STRPOS(column_1,'xyz') from df; + /// ``` + StrPos, /// SQL 'substr' function /// Returns a portion of the data (first character = 0) in the range. /// \[start, start + length] @@ -673,6 +679,7 @@ impl PolarsSQLFunctions { "lower" => Self::Lower, "ltrim" => Self::LTrim, "octet_length" => Self::OctetLength, + "strpos" => Self::StrPos, "regexp_like" => Self::RegexpLike, "replace" => Self::Replace, "reverse" => Self::Reverse, @@ -844,6 +851,10 @@ impl SQLFunctionVisitor<'_> { _ => polars_bail!(InvalidOperation: "Invalid number of arguments for LTrim: {}", function.args.len()), }, OctetLength => self.visit_unary(|e| e.str().len_bytes()), + StrPos => { + // note: 1-indexed, not 0-indexed, and returns zero if match not found; NOTE(review): `fill_null` zeroes every null, so NULL inputs presumably give 0 too — confirm + self.visit_binary(|expr, substring| (expr.str().find(substring, true) + lit(1u32)).fill_null(0u32)) + }, RegexpLike => match function.args.len() { 2 => self.visit_binary(|e, s| e.str().contains(s, true)), 3 => self.try_visit_ternary(|e, pat, flags| { diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 9e5e5562e46b..a3e8834b1f23 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -139,6 +139,15 @@ impl SQLExprVisitor<'_> { escape_char, } => self.visit_like(*negated, expr, pattern, escape_char, true), SQLExpr::Nested(expr) => self.visit_expr(expr), + SQLExpr::Position { expr, r#in } => Ok( + // note: SQL is 1-indexed, not 0-indexed; searches `r#in` for `expr`, and a null result (no match) becomes 0 + (self + .visit_expr(r#in)? 
+ .str() + .find(self.visit_expr(expr)?, true) + + lit(1u32)) + .fill_null(0u32), + ), SQLExpr::RLike { // note: parses both RLIKE and REGEXP negated, @@ -178,8 +187,8 @@ impl SQLExprVisitor<'_> { } let mut lf = self.ctx.execute_query_no_ctes(subquery)?; - let schema = lf.schema()?; + if restriction == SubqueryRestriction::SingleColumn { if schema.len() != 1 { polars_bail!(InvalidOperation: "SQL subquery will return more than one column"); @@ -194,7 +203,6 @@ impl SQLExprVisitor<'_> { if let Some((old_name, _)) = schema_entry { let new_name = String::from(old_name.as_str()) + rand_string.as_str(); lf = lf.rename([old_name.to_string()], [new_name.clone()]); - return Ok(Expr::SubPlan( SpecialEq::new(Arc::new(lf.logical_plan)), vec![new_name], diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst index f055e3807ed6..831edce162b0 100644 --- a/py-polars/docs/source/reference/expressions/string.rst +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -21,6 +21,7 @@ The following methods are available under the `expr.str` attribute. 
Expr.str.extract Expr.str.extract_all Expr.str.extract_groups + Expr.str.find Expr.str.json_decode Expr.str.json_extract Expr.str.json_path_match diff --git a/py-polars/tests/unit/sql/test_strings.py b/py-polars/tests/unit/sql/test_strings.py index a7517f6aa49a..d1f9517c2fd6 100644 --- a/py-polars/tests/unit/sql/test_strings.py +++ b/py-polars/tests/unit/sql/test_strings.py @@ -6,6 +6,7 @@ import polars as pl from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal # TODO: Do not rely on I/O for these tests @@ -174,6 +175,62 @@ def test_string_like(pattern: str, like: str, expected: list[int]) -> None: assert res == expected +def test_string_position() -> None: + df = pl.Series( + name="city", + values=["Dubai", "Abu Dhabi", "Sharjah", "Al Ain", "Ajman", "Ras Al Khaimah"], + ).to_frame() + + with pl.SQLContext(cities=df, eager_execution=True) as ctx: + res = ctx.execute( + """ + SELECT + POSITION('a' IN city) AS a_lc1, + POSITION('A' IN city) AS a_uc1, + STRPOS(city,'a') AS a_lc2, + STRPOS(city,'A') AS a_uc2, + FROM cities + """ + ) + expected_lc = [4, 7, 3, 0, 4, 2]  # 1-based index of first 'a'; 0 = no match + expected_uc = [0, 1, 0, 1, 1, 5]  # 1-based index of first 'A'; 0 = no match + + assert res.to_dict(as_series=False) == { + "a_lc1": expected_lc, + "a_uc1": expected_uc, + "a_lc2": expected_lc, + "a_uc2": expected_uc, + } + + df = pl.DataFrame({"txt": ["AbCdEXz", "XyzFDkE"]}) + with pl.SQLContext(txt=df) as ctx: + res = ctx.execute( + """ + SELECT + txt, + POSITION('E' IN txt) AS match_E, + STRPOS(txt,'X') AS match_X + FROM txt + """, + eager=True, + ) + assert_frame_equal( + res, + pl.DataFrame( + data={ + "txt": ["AbCdEXz", "XyzFDkE"], + "match_E": [5, 7], + "match_X": [6, 1], + }, + schema={ + "txt": pl.String, + "match_E": pl.UInt32, + "match_X": pl.UInt32, + }, + ), + ) + + def test_string_replace() -> None: df = pl.DataFrame({"words": ["Yemeni coffee is the best coffee", "", None]}) with pl.SQLContext(df=df) as ctx: