Skip to content

Commit

Permalink
feat(rust, python): accept expression in str.extract_all (#5742)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 8, 2022
1 parent eb57d67 commit f05e4da
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 16 deletions.
4 changes: 2 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Extract { pat, group_index } => {
map!(strings::extract, &pat, group_index)
}
ExtractAll(pat) => {
map!(strings::extract_all, &pat)
ExtractAll => {
map_as_slice!(strings::extract_all)
}
CountMatch(pat) => {
map!(strings::count_match, &pat)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl FunctionExpr {
match s {
Contains { .. } | EndsWith(_) | StartsWith(_) => with_dtype(DataType::Boolean),
Extract { .. } => same_type(),
ExtractAll(_) => with_dtype(DataType::List(Box::new(DataType::Utf8))),
ExtractAll => with_dtype(DataType::List(Box::new(DataType::Utf8))),
CountMatch(_) => with_dtype(DataType::UInt32),
#[cfg(feature = "string_justify")]
Zfill { .. } | LJust { .. } | RJust { .. } => same_type(),
Expand Down
20 changes: 15 additions & 5 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub enum StringFunction {
width: usize,
fillchar: char,
},
ExtractAll(String),
ExtractAll,
CountMatch(String),
#[cfg(feature = "temporal")]
Strptime(StrpTimeOptions),
Expand Down Expand Up @@ -68,7 +68,7 @@ impl Display for StringFunction {
StringFunction::LJust { .. } => "str.ljust",
#[cfg(feature = "string_justify")]
StringFunction::RJust { .. } => "rjust",
StringFunction::ExtractAll(_) => "extract_all",
StringFunction::ExtractAll => "extract_all",
StringFunction::CountMatch(_) => "count_match",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_) => "strptime",
Expand Down Expand Up @@ -176,11 +176,21 @@ pub(super) fn rstrip(s: &Series, matches: Option<char>) -> PolarsResult<Series>
}
}

pub(super) fn extract_all(s: &Series, pat: &str) -> PolarsResult<Series> {
let pat = pat.to_string();
pub(super) fn extract_all(args: &[Series]) -> PolarsResult<Series> {
let s = &args[0];
let pat = &args[1];

let ca = s.utf8()?;
ca.extract_all(&pat).map(|ca| ca.into_series())
let pat = pat.utf8()?;

if pat.len() == 1 {
let pat = pat
.get(0)
.ok_or_else(|| PolarsError::ComputeError("Expected a pattern got null".into()))?;
ca.extract_all(pat).map(|ca| ca.into_series())
} else {
ca.extract_all_many(pat).map(|ca| ca.into_series())
}
}

pub(super) fn count_match(s: &Series, pat: &str) -> PolarsResult<Series> {
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ impl StringNameSpace {
}

/// Extract each successive non-overlapping match in an individual string as an array
pub fn extract_all(self, pat: &str) -> Expr {
let pat = pat.to_string();
self.0.map_private(StringFunction::ExtractAll(pat).into())
pub fn extract_all(self, pat: Expr) -> Expr {
self.0
.map_many_private(StringFunction::ExtractAll.into(), &[pat], false)
}

/// Count all successive non-overlapping regex matches.
Expand Down
28 changes: 28 additions & 0 deletions polars/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,34 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
Ok(builder.finish())
}

/// Extract each successive non-overlapping regex match in an individual string as an array
fn extract_all_many(&self, pat: &Utf8Chunked) -> PolarsResult<ListChunked> {
let ca = self.as_utf8();
if ca.len() != pat.len() {
return Err(PolarsError::ComputeError(
"pattern's length does not match that of the argument Series".into(),
));
}

let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

for (opt_s, opt_pat) in ca.into_iter().zip(pat.into_iter()) {
match (opt_s, opt_pat) {
(_, None) | (None, _) => builder.append_null(),
(Some(s), Some(pat)) => {
let reg = Regex::new(pat)?;
let mut iter = reg.find_iter(s).map(|m| m.as_str()).peekable();
if iter.peek().is_some() {
builder.append_values_iter(iter);
} else {
builder.append_null()
}
}
}
}
Ok(builder.finish())
}

/// Count all successive non-overlapping regex matches.
fn count_match(&self, pat: &str) -> PolarsResult<UInt32Chunked> {
let ca = self.as_utf8();
Expand Down
5 changes: 3 additions & 2 deletions py-polars/polars/internals/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def extract(self, pattern: str, group_index: int = 1) -> pli.Expr:
"""
return pli.wrap_expr(self._pyexpr.str_extract(pattern, group_index))

def extract_all(self, pattern: str) -> pli.Expr:
def extract_all(self, pattern: str | pli.Expr) -> pli.Expr:
r"""
Extracts all matches for the given regex pattern.
Expand Down Expand Up @@ -830,7 +830,8 @@ def extract_all(self, pattern: str) -> pli.Expr:
└────────────────┘
"""
return pli.wrap_expr(self._pyexpr.str_extract_all(pattern))
pattern = pli.expr_to_lit_or_expr(pattern, str_to_lit=True)
return pli.wrap_expr(self._pyexpr.str_extract_all(pattern._pyexpr))

def count_match(self, pattern: str) -> pli.Expr:
r"""
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def extract(self, pattern: str, group_index: int = 1) -> pli.Series:
"""

def extract_all(self, pattern: str) -> pli.Series:
def extract_all(self, pattern: str | pli.Series) -> pli.Series:
r"""
Extracts all matches for the given regex pattern.
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,8 @@ impl PyExpr {
self.inner.clone().str().extract(pat, group_index).into()
}

pub fn str_extract_all(&self, pat: &str) -> PyExpr {
self.inner.clone().str().extract_all(pat).into()
pub fn str_extract_all(&self, pat: PyExpr) -> PyExpr {
self.inner.clone().str().extract_all(pat.inner).into()
}

pub fn count_match(&self, pat: &str) -> PyExpr {
Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/unit/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def test_extract_all_count() -> None:
assert df["foo"].str.count_match(r"a").dtype == pl.UInt32


def test_extract_all_many() -> None:
df = pl.DataFrame({"foo": ["ab", "abc", "abcd"], "re": ["a", "bc", "a.c"]})
assert df["foo"].str.extract_all(df["re"]).to_list() == [["a"], ["bc"], ["abc"]]


def test_zfill() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit f05e4da

Please sign in to comment.