Skip to content

Commit

Permalink
window functions: sort cached groups if needed (#4184)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 30, 2022
1 parent 3e665fd commit b7ec308
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
6 changes: 5 additions & 1 deletion polars/polars-core/src/frame/groupby/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ impl GroupsProxy {
#[cfg(feature = "private")]
pub fn sort(&mut self) {
match self {
GroupsProxy::Idx(groups) => groups.sort(),
GroupsProxy::Idx(groups) => {
if !groups.is_sorted() {
groups.sort()
}
}
GroupsProxy::Slice { groups, rolling } => {
if !*rolling {
groups.sort_unstable_by_key(|[first, _]| *first);
Expand Down
12 changes: 8 additions & 4 deletions polars/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,8 @@ impl PhysicalExpr for WindowExpr {
.all(|s| matches!(s.is_sorted(), IsSorted::Ascending | IsSorted::Descending));
let explicit_list_agg = self.is_explicit_list_agg();

let create_groups = || {
// if we flatten this column we need to make sure the groups are sorted.
let sorted = self.options.explode ||
// if we flatten this column we need to make sure the groups are sorted.
let sort_groups = self.options.explode ||
// if not
// `col().over()`
// and not
Expand All @@ -240,7 +239,9 @@ impl PhysicalExpr for WindowExpr {
// and keys are sorted
// we may optimize with explode call
(!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());
let gb = df.groupby_with_series(groupby_columns.clone(), true, sorted)?;

let create_groups = || {
let gb = df.groupby_with_series(groupby_columns.clone(), true, sort_groups)?;
let out: Result<GroupsProxy> = Ok(gb.take_groups());
out
};
Expand All @@ -261,6 +262,9 @@ impl PhysicalExpr for WindowExpr {
if df.height() > 0 {
assert!(!gt.is_empty());
};
if sort_groups {
gt.sort()
}

// We take now, but it is important that we set this before we return!
// a next windows function may get this cached key and get an empty if this
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,25 @@ def test_count_window() -> None:
.with_column(pl.count().over("a"))["count"]
.to_list()
) == [2, 2, 1]


def test_window_cached_keys_sorted_update_4183() -> None:
    """Regression test for issue #4183.

    Evaluates two window expressions over the same key (``customer_ID``) in a
    single ``select`` on a frame sorted by that key, and checks that both the
    per-group count and the ordinal rank line up with the original row order.
    """
    frame = pl.DataFrame(
        {
            "customer_ID": ["0", "0", "1"],
            "date": [1, 2, 3],
        }
    )

    # Both expressions partition by the same column.
    count_per_customer = (
        pl.count("date").over(pl.col("customer_ID")).alias("count")
    )
    rank_per_customer = (
        pl.col("date")
        .rank(method="ordinal")
        .over(pl.col("customer_ID"))
        .alias("rank")
    )

    ordered = frame.sort(by=["customer_ID", "date"])
    result = ordered.select([count_per_customer, rank_per_customer]).to_dict(False)

    expected = {"count": [2, 2, 1], "rank": [1, 2, 1]}
    assert result == expected

0 comments on commit b7ec308

Please sign in to comment.