Skip to content

Commit

Permalink
fix rolling groupby ordering with 'by' argument (#3720)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 17, 2022
1 parent af3327e commit 174a213
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
17 changes: 12 additions & 5 deletions polars/polars-time/src/groupby/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,22 @@ impl Wrap<&DataFrame> {
})
.collect()
});
GroupsProxy::Idx(groupsidx)
let mut groups = GroupsProxy::Idx(groupsidx);
groups.sort();
groups
};
let dt = dt.cast(time_type).unwrap();

// the ordering has changed due to the groupby
if !by.is_empty() {
for key in by.iter_mut() {
*key = unsafe { key.agg_first(&groups) };
unsafe {
for key in by.iter_mut() {
*key = key.agg_first(&groups);
}
}
}
dt.cast(time_type).map(|s| (s, by, groups))
};

Ok((dt, by, groups))
}
}

Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,3 +1004,49 @@ def test_unique_counts_on_dates() -> None:
"dt_ms": [3],
"date": [3],
}


def test_groupby_rolling_by_ordering() -> None:
    """Rolling groupby with a `by` argument must keep keys aligned with time labels.

    The frame interleaves two keys ("A" and "B") over seven consecutive
    minutes. After `groupby_rolling(..., by="key")` the "key" and "dt" output
    columns are expected to come back in the same order as the input rows.
    """
    # Seven one-minute steps: 00:01 through 00:07 on 2022-01-01.
    timestamps = [datetime(2022, 1, 1, 0, minute) for minute in range(1, 8)]
    keys = ["A", "A", "B", "B", "A", "B", "A"]

    df = pl.DataFrame(
        {
            "dt": timestamps,
            "key": keys,
            "val": [1, 1, 1, 1, 1, 1, 1],
        }
    )

    rolled = df.groupby_rolling(
        index_column="dt",
        period="2m",
        closed="both",
        offset="-1m",
        by="key",
    )
    result = rolled.agg(
        [
            pl.col("val").sum().alias("sum val"),
        ]
    ).to_dict(False)

    # Keys and time labels must match the input order row for row.
    expected = {
        "key": ["A", "A", "B", "B", "A", "B", "A"],
        "dt": [
            datetime(2022, 1, 1, 0, 1),
            datetime(2022, 1, 1, 0, 2),
            datetime(2022, 1, 1, 0, 3),
            datetime(2022, 1, 1, 0, 4),
            datetime(2022, 1, 1, 0, 5),
            datetime(2022, 1, 1, 0, 6),
            datetime(2022, 1, 1, 0, 7),
        ],
        "sum val": [2, 2, 2, 2, 1, 1, 1],
    }
    assert result == expected

0 comments on commit 174a213

Please sign in to comment.