Skip to content

Commit

Permalink
fold regex expand (#4181)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 29, 2022
1 parent 56cc086 commit 08f6f73
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ where
F: Fn(Series, Series) -> Result<Series> + Send + Sync + Clone,
{
let mut exprs = exprs.as_ref().to_vec();
if exprs.iter().any(has_wildcard) {
if exprs.iter().any(|e| has_wildcard(e) | has_regex(e)) {
exprs.push(acc);

let function = SpecialEq::new(Arc::new(move |series: &mut [Series]| {
Expand Down
8 changes: 8 additions & 0 deletions polars/polars-lazy/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ pub(crate) fn has_wildcard(current_expr: &Expr) -> bool {
has_expr(current_expr, |e| matches!(e, Expr::Wildcard))
}

/// Reports whether `has_expr` finds a regex-style column selector for `current_expr`.
///
/// A node counts as a regex selector when it is an `Expr::Column` whose name
/// starts with `^` and ends with `$`; every other node is rejected.
pub(crate) fn has_regex(current_expr: &Expr) -> bool {
    has_expr(current_expr, |node| {
        matches!(node, Expr::Column(name) if name.starts_with('^') && name.ends_with('$'))
    })
}

/// Reports whether `has_expr` finds an `Expr::Nth` node for `current_expr`.
pub(crate) fn has_nth(current_expr: &Expr) -> bool {
    // The payload of `Nth` is irrelevant here; only the variant is checked.
    let is_nth = |node: &Expr| matches!(node, Expr::Nth(_));
    has_expr(current_expr, is_nth)
}
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/test_expr_multi_cols.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,23 @@ def test_exclude_name_from_dtypes() -> None:
assert df.with_column(pl.col(pl.Utf8).exclude("a").suffix("_foo")).frame_equal(
pl.DataFrame({"a": ["a"], "b": ["b"], "b_foo": ["b"]})
)


def test_fold_regex_expand() -> None:
    """A regex column selector given to ``pl.fold`` folds over every matching column."""
    df = pl.DataFrame(
        {
            "x": [0, 1, 2],
            "y_1": [1.1, 2.2, 3.3],
            "y_2": [1.0, 2.5, 3.5],
        }
    )

    # Fold with addition over the columns selected by the regex, starting from a literal 0.
    y_sum = pl.fold(
        acc=pl.lit(0),
        f=lambda acc, x: acc + x,
        exprs=pl.col("^y_.*$"),
    ).alias("y_sum")

    result = df.with_column(y_sum).to_dict(False)
    expected = {
        "x": [0, 1, 2],
        "y_1": [1.1, 2.2, 3.3],
        "y_2": [1.0, 2.5, 3.5],
        "y_sum": [2.1, 4.7, 6.8],
    }
    assert result == expected

0 comments on commit 08f6f73

Please sign in to comment.