Skip to content

Commit

Permalink
feat(rust,python): add new str.find expression, returning the index…
Browse files Browse the repository at this point in the history
… of a regex pattern or literal substring (#13561)
  • Loading branch information
alexander-beedie committed Jan 10, 2024
1 parent 00ce1f7 commit 94306a4
Show file tree
Hide file tree
Showing 9 changed files with 404 additions and 30 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/ops/apply.rs
Expand Up @@ -14,7 +14,7 @@ impl<T> ChunkedArray<T>
where
T: PolarsDataType,
{
// Applies a function to all elements , regardless of whether they
// Applies a function to all elements, regardless of whether they
// are null or not, after which the null mask is copied from the
// original array.
pub fn apply_values_generic<'a, U, K, F>(&'a self, mut op: F) -> ChunkedArray<U>
Expand Down
63 changes: 61 additions & 2 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Expand Up @@ -151,6 +151,46 @@ pub trait StringNameSpaceImpl: AsString {
}
}

/// Return the byte offset of the first match of `pat` within each string value,
/// where the patterns are themselves supplied as a string column.
///
/// * `pat` - either a single pattern (broadcast over every value) or one
///   pattern per row.
/// * `literal` - treat each pattern as a plain substring instead of a regex.
/// * `strict` - when a pattern is not a valid regex, return an error if `true`;
///   otherwise produce null for the rows that use that pattern.
///
/// Rows where the value or the pattern is null, or where nothing matches,
/// yield null.
fn find_chunked(
    &self,
    pat: &StringChunked,
    literal: bool,
    strict: bool,
) -> PolarsResult<UInt32Chunked> {
    let ca = self.as_string();
    if pat.len() == 1 {
        // Single pattern: delegate to the scalar-pattern implementations so the
        // pattern is escaped/compiled once rather than per row.
        return if let Some(pat) = pat.get(0) {
            if literal {
                ca.find_literal(pat)
            } else {
                ca.find(pat, strict)
            }
        } else {
            Ok(UInt32Chunked::full_null(ca.name(), ca.len()))
        };
    } else if ca.len() == 1 && ca.null_count() == 1 {
        // A single null value broadcast against the patterns is null everywhere.
        return Ok(UInt32Chunked::full_null(ca.name(), ca.len().max(pat.len())));
    }
    if literal {
        Ok(broadcast_binary_elementwise(
            ca,
            pat,
            |src: Option<&str>, pat: Option<&str>| src?.find(pat?).map(|idx| idx as u32),
        ))
    } else {
        // note: sqrt(n) regex cache is not too small, not too large.
        let mut rx_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize);
        let matcher = |src: Option<&str>, pat: Option<&str>| -> PolarsResult<Option<u32>> {
            if let (Some(src), Some(pat)) = (src, pat) {
                return match rx_cache.try_get_or_insert_with(pat, |p| Regex::new(p)) {
                    Ok(rx) => Ok(rx.find(src).map(|m| m.start() as u32)),
                    // Honour `strict` the same way the single-pattern path does:
                    // an invalid pattern masks the row out instead of failing.
                    Err(_) if !strict => Ok(None),
                    Err(e) => Err(e.into()),
                };
            }
            Ok(None)
        };
        broadcast_try_binary_elementwise(ca, pat, matcher)
    }
}

/// Get the length of the string values as number of chars.
fn str_len_chars(&self) -> UInt32Chunked {
let ca = self.as_string();
Expand Down Expand Up @@ -200,10 +240,8 @@ pub trait StringNameSpaceImpl: AsString {
/// Check if strings contain a regex pattern.
fn contains(&self, pat: &str, strict: bool) -> PolarsResult<BooleanChunked> {
let ca = self.as_string();

let res_reg = Regex::new(pat);
let opt_reg = if strict { Some(res_reg?) } else { res_reg.ok() };

let out: BooleanChunked = if let Some(reg) = opt_reg {
ca.apply_values_generic(|s| reg.is_match(s))
} else {
Expand All @@ -220,6 +258,27 @@ pub trait StringNameSpaceImpl: AsString {
self.contains(regex::escape(lit).as_str(), true)
}

/// Locate a literal substring in each string value and return the offset at
/// which it first occurs (null when absent).
fn find_literal(&self, lit: &str) -> PolarsResult<UInt32Chunked> {
    // Escape regex metacharacters so the regex-based `find` matches the text
    // verbatim.
    let escaped = regex::escape(lit);
    self.find(&escaped, true)
}

/// Locate the first match of the regular expression `pat` in each string value
/// and return the offset at which the match starts.
///
/// Null values and values without a match yield null. If `pat` fails to
/// compile, an error is returned when `strict` is `true`; otherwise every row
/// of the result is null.
fn find(&self, pat: &str, strict: bool) -> PolarsResult<UInt32Chunked> {
    let ca = self.as_string();
    // Compile once up front; bail out early on an unusable pattern.
    let rx = match Regex::new(pat) {
        Ok(compiled) => compiled,
        Err(_) if !strict => return Ok(UInt32Chunked::full_null(ca.name(), ca.len())),
        Err(e) => {
            return Err(PolarsError::ComputeError(
                format!("Invalid regular expression: {}", e).into(),
            ))
        },
    };
    let offsets =
        ca.apply_generic(|opt_s| opt_s.and_then(|s| rx.find(s).map(|m| m.start() as u32)));
    Ok(offsets)
}

/// Replace the leftmost regex-matched (sub)string with another string
fn replace<'a>(&'a self, pat: &str, val: &str) -> PolarsResult<StringChunked> {
let reg = Regex::new(pat)?;
Expand Down
16 changes: 15 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/strings.rs
Expand Up @@ -47,6 +47,10 @@ pub enum StringFunction {
dtype: DataType,
pat: String,
},
Find {
literal: bool,
strict: bool,
},
#[cfg(feature = "string_to_integer")]
ToInteger(u32, bool),
LenBytes,
Expand Down Expand Up @@ -135,6 +139,7 @@ impl StringFunction {
ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()),
#[cfg(feature = "string_to_integer")]
ToInteger { .. } => mapper.with_dtype(DataType::Int64),
Find { .. } => mapper.with_dtype(DataType::UInt32),
#[cfg(feature = "extract_jsonpath")]
JsonDecode { dtype, .. } => mapper.with_opt_dtype(dtype.clone()),
LenBytes => mapper.with_dtype(DataType::UInt32),
Expand Down Expand Up @@ -207,6 +212,7 @@ impl Display for StringFunction {
ExtractGroups { .. } => "extract_groups",
#[cfg(feature = "string_to_integer")]
ToInteger { .. } => "to_integer",
Find { .. } => "find",
#[cfg(feature = "extract_jsonpath")]
JsonDecode { .. } => "json_decode",
LenBytes => "len_bytes",
Expand Down Expand Up @@ -291,6 +297,7 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
ExtractGroups { pat, dtype } => {
map!(strings::extract_groups, &pat, &dtype)
},
Find { literal, strict } => map_as_slice!(strings::find, literal, strict),
LenBytes => map!(strings::len_bytes),
LenChars => map!(strings::len_chars),
#[cfg(feature = "string_pad")]
Expand Down Expand Up @@ -427,6 +434,14 @@ pub(super) fn contains(s: &[Series], literal: bool, strict: bool) -> PolarsResul
.map(|ok| ok.into_series())
}

/// Expression-level entry point for `str.find`: `s[0]` holds the string
/// values to search and `s[1]` the pattern(s); both must be string series.
#[cfg(feature = "regex")]
pub(super) fn find(s: &[Series], literal: bool, strict: bool) -> PolarsResult<Series> {
    let (values, patterns) = (s[0].str()?, s[1].str()?);
    let offsets = values.find_chunked(patterns, literal, strict)?;
    Ok(offsets.into_series())
}

pub(super) fn ends_with(s: &[Series]) -> PolarsResult<Series> {
let ca = &s[0].str()?.as_binary();
let suffix = &s[1].str()?.as_binary();
Expand Down Expand Up @@ -845,7 +860,6 @@ pub(super) fn replace(s: &[Series], literal: bool, n: i64) -> PolarsResult<Serie
let column = &s[0];
let pat = &s[1];
let val = &s[2];

let all = n < 0;

let column = column.str()?;
Expand Down
28 changes: 28 additions & 0 deletions crates/polars-plan/src/dsl/string.rs
Expand Up @@ -192,6 +192,34 @@ impl StringNameSpace {
self.0.map_private(StringFunction::ZFill(length).into())
}

/// Build an expression that finds the index of a literal substring within
/// each string value; `pat` is the expression supplying the substring.
#[cfg(feature = "regex")]
pub fn find_literal(self, pat: Expr) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Find {
        literal: true,
        strict: false,
    });
    self.0.map_many_private(function, &[pat], false, true)
}

/// Build an expression that finds the index of the first substring matching a
/// regular expression within each string value.
///
/// `pat` is the expression supplying the pattern; `strict` is forwarded to the
/// underlying `StringFunction::Find`.
#[cfg(feature = "regex")]
pub fn find(self, pat: Expr, strict: bool) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Find {
        literal: false,
        strict,
    });
    self.0.map_many_private(function, &[pat], false, true)
}

/// Extract each successive non-overlapping match in an individual string as an array
pub fn extract_all(self, pat: Expr) -> Expr {
self.0
Expand Down
114 changes: 104 additions & 10 deletions py-polars/polars/expr/string.py
Expand Up @@ -954,7 +954,7 @@ def contains(
self, pattern: str | Expr, *, literal: bool = False, strict: bool = True
) -> Expr:
"""
Check if string contains a substring that matches a regex.
Check if string contains a substring that matches a pattern.
Parameters
----------
Expand Down Expand Up @@ -995,18 +995,19 @@ def contains(
--------
starts_with : Check if string values start with a substring.
ends_with : Check if string values end with a substring.
find : Return the index of the first substring matching a pattern.
Examples
--------
>>> df = pl.DataFrame({"a": ["Crab", "cat and dog", "rab$bit", None]})
>>> df = pl.DataFrame({"txt": ["Crab", "cat and dog", "rab$bit", None]})
>>> df.select(
... pl.col("a"),
... pl.col("a").str.contains("cat|bit").alias("regex"),
... pl.col("a").str.contains("rab$", literal=True).alias("literal"),
... pl.col("txt"),
... pl.col("txt").str.contains("cat|bit").alias("regex"),
... pl.col("txt").str.contains("rab$", literal=True).alias("literal"),
... )
shape: (4, 3)
┌─────────────┬───────┬─────────┐
a ┆ regex ┆ literal │
txt ┆ regex ┆ literal │
│ --- ┆ --- ┆ --- │
│ str ┆ bool ┆ bool │
╞═════════════╪═══════╪═════════╡
Expand All @@ -1019,6 +1020,99 @@ def contains(
pattern = parse_as_expression(pattern, str_as_lit=True)
return wrap_expr(self._pyexpr.str_contains(pattern, literal, strict))

def find(
    self, pattern: str | Expr, *, literal: bool = False, strict: bool = True
) -> Expr:
    """
    Return the index position of the first substring matching a pattern.

    If the pattern is not found, returns None.

    Parameters
    ----------
    pattern
        A valid regular expression pattern, compatible with the `regex crate
        <https://docs.rs/regex/latest/regex/>`_.
    literal
        Treat `pattern` as a literal string, not as a regular expression.
    strict
        Raise an error if the underlying pattern is not a valid regex,
        otherwise mask out with a null value.

    Notes
    -----
    To modify regular expression behaviour (such as case-sensitivity) with
    flags, use the inline `(?iLmsuxU)` syntax. For example:

    >>> pl.DataFrame({"s": ["AAA", "aAa", "aaa"]}).with_columns(
    ...     default_match=pl.col("s").str.find("Aa"),
    ...     insensitive_match=pl.col("s").str.find("(?i)Aa"),
    ... )
    shape: (3, 3)
    ┌─────┬───────────────┬───────────────────┐
    │ s   ┆ default_match ┆ insensitive_match │
    │ --- ┆ ---           ┆ ---               │
    │ str ┆ u32           ┆ u32               │
    ╞═════╪═══════════════╪═══════════════════╡
    │ AAA ┆ null          ┆ 0                 │
    │ aAa ┆ 1             ┆ 0                 │
    │ aaa ┆ null          ┆ 0                 │
    └─────┴───────────────┴───────────────────┘

    See the regex crate's section on `grouping and flags
    <https://docs.rs/regex/latest/regex/#grouping-and-flags>`_ for
    additional information about the use of inline expression modifiers.

    See Also
    --------
    contains : Check if string contains a substring that matches a pattern.

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {
    ...         "txt": ["Crab", "Lobster", None, "Crustaceon"],
    ...         "pat": ["a[bc]", "b.t", "[aeiuo]", "(?i)A[BC]"],
    ...     }
    ... )

    Find the index of the first substring matching a regex or literal pattern:

    >>> df.select(
    ...     pl.col("txt"),
    ...     pl.col("txt").str.find("a|e").alias("a|e (regex)"),
    ...     pl.col("txt").str.find("e", literal=True).alias("e (lit)"),
    ... )
    shape: (4, 3)
    ┌────────────┬─────────────┬─────────┐
    │ txt        ┆ a|e (regex) ┆ e (lit) │
    │ ---        ┆ ---         ┆ ---     │
    │ str        ┆ u32         ┆ u32     │
    ╞════════════╪═════════════╪═════════╡
    │ Crab       ┆ 2           ┆ null    │
    │ Lobster    ┆ 5           ┆ 5       │
    │ null       ┆ null        ┆ null    │
    │ Crustaceon ┆ 5           ┆ 7       │
    └────────────┴─────────────┴─────────┘

    Match against a pattern found in another column (or expression):

    >>> df.with_columns(pl.col("txt").str.find(pl.col("pat")).alias("find_pat"))
    shape: (4, 3)
    ┌────────────┬───────────┬──────────┐
    │ txt        ┆ pat       ┆ find_pat │
    │ ---        ┆ ---       ┆ ---      │
    │ str        ┆ str       ┆ u32      │
    ╞════════════╪═══════════╪══════════╡
    │ Crab       ┆ a[bc]     ┆ 2        │
    │ Lobster    ┆ b.t       ┆ 2        │
    │ null       ┆ [aeiuo]   ┆ null     │
    │ Crustaceon ┆ (?i)A[BC] ┆ 5        │
    └────────────┴───────────┴──────────┘
    """
    # Plain strings become literal expressions; Expr inputs pass through.
    pattern_expr = parse_as_expression(pattern, str_as_lit=True)
    py_result = self._pyexpr.str_find(pattern_expr, literal, strict)
    return wrap_expr(py_result)

def ends_with(self, suffix: str | Expr) -> Expr:
"""
Check if string values end with a substring.
Expand Down Expand Up @@ -1298,8 +1392,8 @@ def extract(self, pattern: str, group_index: int = 1) -> Expr:
Parameters
----------
pattern
A valid regular expression pattern, compatible with the `regex crate
<https://docs.rs/regex/latest/regex/>`_.
A valid regular expression pattern containing at least one capture group,
compatible with the `regex crate <https://docs.rs/regex/latest/regex/>`_.
group_index
Index of the targeted capture group.
Group 0 means the whole pattern, the first group begins at index 1.
Expand Down Expand Up @@ -1466,8 +1560,8 @@ def extract_groups(self, pattern: str) -> Expr:
Parameters
----------
pattern
A valid regular expression pattern, compatible with the `regex crate
<https://docs.rs/regex/latest/regex/>`_.
A valid regular expression pattern containing at least one capture group,
compatible with the `regex crate <https://docs.rs/regex/latest/regex/>`_.
Notes
-----
Expand Down
22 changes: 10 additions & 12 deletions py-polars/polars/io/csv/functions.py
Expand Up @@ -194,18 +194,16 @@ def read_csv(
--------
>>> pl.read_csv("data.csv", separator="|") # doctest: +SKIP
Reproducible example using BytesIO object, parsing dates.
>>> import io # doctest: +SKIP
>>> source = io.BytesIO(
... (
... "ID,Name,Birthday\n"
... "1,Alice,1995-07-12\n"
... "2,Bob,1990-09-20\n"
... "3,Charlie,2002-03-08"
... ).encode()
... ) # doctest: +SKIP
>>> pl.read_csv(source, try_parse_dates=True) # doctest: +SKIP
Demonstrate use against a BytesIO object, parsing string dates.
>>> from io import BytesIO
>>> data = BytesIO(
... b"ID,Name,Birthday\n"
... b"1,Alice,1995-07-12\n"
... b"2,Bob,1990-09-20\n"
... b"3,Charlie,2002-03-08\n"
... )
>>> pl.read_csv(data, try_parse_dates=True)
shape: (3, 3)
┌─────┬─────────┬────────────┐
│ ID ┆ Name ┆ Birthday │
Expand Down

0 comments on commit 94306a4

Please sign in to comment.