Skip to content

Commit

Permalink
fix[expr]: sort_by ensure groups are synced (#4392)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 13, 2022
1 parent 14639ae commit b14e6fb
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
25 changes: 22 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,17 @@ impl PhysicalExpr for SortByExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?;

let reverse = prepare_reverse(&self.reverse, self.by.len());

let groups = if self.by.len() == 1 {
let (groups, ordered_by_group_operation) = if self.by.len() == 1 {
let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?;
let sort_by_s = ac_sort_by.flat_naive().into_owned();

let ordered_by_group_operation = matches!(
ac_sort_by.update_groups,
UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
);
let groups = ac_sort_by.groups();

let groups = groups
Expand Down Expand Up @@ -122,7 +128,7 @@ impl PhysicalExpr for SortByExpr {
})
.collect();

GroupsProxy::Idx(groups)
(GroupsProxy::Idx(groups), ordered_by_group_operation)
} else {
let mut ac_sort_by = self
.by
Expand All @@ -133,6 +139,11 @@ impl PhysicalExpr for SortByExpr {
.iter()
.map(|s| s.flat_naive().into_owned())
.collect::<Vec<_>>();

let ordered_by_group_operation = matches!(
ac_sort_by[0].update_groups,
UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
);
let groups = ac_sort_by[0].groups();

let groups = groups
Expand Down Expand Up @@ -168,9 +179,17 @@ impl PhysicalExpr for SortByExpr {
})
.collect();

GroupsProxy::Idx(groups)
(GroupsProxy::Idx(groups), ordered_by_group_operation)
};

// if the rhs is already aggregated once,
// it is reordered by the groupby operation
// we must ensure that we are as well.
if ordered_by_group_operation {
let s = ac_in.aggregated();
ac_in.with_series(s.explode().unwrap(), false);
}

ac_in.with_groups(groups);
Ok(ac_in)
}
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,23 @@ def test_groupby_signed_transmutes() -> None:
"foo": [-1, -2, -3, -4, -5],
"bar": [500.0, 600.0, 700.0, 800.0, 900.0],
}


def test_argsort_sort_by_groups_update__4360() -> None:
df = pl.DataFrame(
{
"group": ["a"] * 3 + ["b"] * 3,
"col1": [1, 2, 3, 300, 200, 100],
"col2": [1, 2, 3, 300, 200, 100],
}
)
assert (
df.select(
[
pl.col("col1")
.sort_by(pl.col("col2").arg_sort())
.over("group")
.alias("1_argsort_2"),
]
)
)["1_argsort_2"].to_list() == [1, 2, 3, 300, 200, 100]

0 comments on commit b14e6fb

Please sign in to comment.