Skip to content

Commit

Permalink
fix[rust]: ensure left argument of functions is iterated first (#4447)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 16, 2022
1 parent 39fcfc5 commit 32e33fe
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
13 changes: 10 additions & 3 deletions polars/polars-lazy/src/logical_plan/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ macro_rules! push_expr {
$push(falsy);
$push(predicate)
}
AnonymousFunction { input, .. } => input.$iter().for_each(|e| $push(e)),
Function { input, .. } => input.$iter().for_each(|e| $push(e)),
// we iterate in reverse order, so that the lhs is popped first and will be found
// as the root columns/ input columns by `_suffix` and `_keep_name` etc.
AnonymousFunction { input, .. } => input.$iter().rev().for_each(|e| $push(e)),
Function { input, .. } => input.$iter().rev().for_each(|e| $push(e)),
Shift { input, .. } => $push(input),
Reverse(e) => $push(e),
Duplicated(e) => $push(e),
Expand Down Expand Up @@ -216,7 +218,12 @@ impl AExpr {
push(falsy);
push(predicate)
}
AnonymousFunction { input, .. } | Function { input, .. } => input.iter().for_each(push),
AnonymousFunction { input, .. } | Function { input, .. } =>
// we iterate in reverse order, so that the lhs is popped first and will be found
// as the root columns/ input columns by `_suffix` and `_keep_name` etc.
{
input.iter().rev().for_each(push)
}
Shift { input, .. } => push(input),
Reverse(e) => push(e),
Duplicated(e) => push(e),
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/test_expr_multi_cols.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,23 @@ def test_argsort_argument_expansion() -> None:
assert df.select(
pl.all().exclude("sort_order").sort_by(pl.col("sort_order")).arg_sort()
).to_dict(False) == {"col1": [2, 1, 0], "col2": [2, 1, 0]}


def test_append_root_columns() -> None:
df = pl.DataFrame(
{
"col1": [1, 2],
"col2": [10, 20],
"other": [100, 200],
}
)
assert (
df.select(
[
pl.col("col2").append(pl.col("other")),
pl.col("col1").append(pl.col("other")).keep_name(),
pl.col("col1").append(pl.col("other")).prefix("prefix_"),
pl.col("col1").append(pl.col("other")).suffix("_suffix"),
]
)
).columns == ["col2", "col1", "prefix_col1", "col1_suffix"]

0 comments on commit 32e33fe

Please sign in to comment.