Skip to content

Commit

Permalink
fix[rust]: sort_by(argsort) window expression did not use updated gro…
Browse files Browse the repository at this point in the history
…up (#4426)
  • Loading branch information
ritchie46 committed Aug 15, 2022
1 parent 9261e42 commit 34c5e89
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 24 deletions.
6 changes: 2 additions & 4 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,8 @@ impl<'a> AggregationContext<'a> {
let out = unsafe { s.agg_list(&self.groups) };
self.state = AggState::AggregatedList(out.clone());

if !self.sorted {
self.sorted = true;
self.update_groups = UpdateGroups::WithGroupsLen;
};
self.sorted = true;
self.update_groups = UpdateGroups::WithGroupsLen;
out
}
AggState::AggregatedList(s) | AggState::AggregatedFlat(s) => s,
Expand Down
15 changes: 8 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ impl PhysicalExpr for WindowExpr {
//
// - 3.3. MAP to original locations
// This will be done for list aggregations that are not explicitly aggregated as list
// `(col("x").sum() * col("y")).over("groups")`
// `(col("x").sum() * col("y")).over("groups")
// This can be used to reverse, sort, shuffle etc. the values in a group

// 4. select the final column and return
let groupby_columns = self
Expand Down Expand Up @@ -362,8 +363,8 @@ impl PhysicalExpr for WindowExpr {
}
}
GroupsProxy::Slice { groups, .. } => {
for g in groups {
original_idx.extend(g[0]..g[0] + g[1])
for &[first, len] in groups {
original_idx.extend(first..first + len)
}
}
};
Expand All @@ -390,8 +391,8 @@ impl PhysicalExpr for WindowExpr {
}
}
GroupsProxy::Slice { groups, .. } => {
for g in groups {
idx_mapping.extend((g[0]..g[0] + g[1]).zip(&mut original_idx));
for &[first, len] in groups {
idx_mapping.extend((first..first + len).zip(&mut original_idx));
}
}
}
Expand All @@ -406,8 +407,8 @@ impl PhysicalExpr for WindowExpr {
}
}
GroupsProxy::Slice { groups, .. } => {
for g in groups {
idx_mapping.extend((g[0]..g[0] + g[1]).zip(&mut original_idx));
for &[first, len] in groups {
idx_mapping.extend((first..first + len).zip(&mut original_idx));
}
}
}
Expand Down
34 changes: 21 additions & 13 deletions py-polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,29 @@ def test_groupby_signed_transmutes() -> None:
def test_argsort_sort_by_groups_update__4360() -> None:
df = pl.DataFrame(
{
"group": ["a"] * 3 + ["b"] * 3,
"col1": [1, 2, 3, 300, 200, 100],
"col2": [1, 2, 3, 300, 200, 100],
"group": ["a"] * 3 + ["b"] * 3 + ["c"] * 3,
"col1": [1, 2, 3] * 3,
"col2": [1, 2, 3, 3, 2, 1, 2, 3, 1],
}
)
assert (
df.select(
[
pl.col("col1")
.sort_by(pl.col("col2").arg_sort())
.over("group")
.alias("1_argsort_2"),
]
)
)["1_argsort_2"].to_list() == [1, 2, 3, 300, 200, 100]

out = df.with_column(
pl.col("col2").arg_sort().over("group").alias("col2_argsort")
).with_columns(
[
pl.col("col1")
.sort_by(pl.col("col2_argsort"))
.over("group")
.alias("result_a"),
pl.col("col1")
.sort_by(pl.col("col2").arg_sort())
.over("group")
.alias("result_b"),
]
)

pl.testing.assert_series_equal(out["result_a"], out["result_b"], check_names=False)
assert out["result_a"].to_list() == [1, 2, 3, 3, 2, 1, 2, 3, 1]


def test_unique_order() -> None:
Expand Down

0 comments on commit 34c5e89

Please sign in to comment.