rolling groupby fix index column output order (#3806)
ritchie46 committed Jun 25, 2022
1 parent f3f2db2 commit d919f4a
Showing 2 changed files with 22 additions and 14 deletions.
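
In user-facing terms, this commit makes groupby_rolling with a "by" argument emit the index_column in the same group-major order as the "by" keys, instead of leaving it in its pre-grouping order. A minimal sketch of the scenario (the period value and the val contents are illustrative assumptions, not taken from the test; the row-order comment matches the updated expectation in test_groupby_rolling_by_ordering below):

import polars as pl
from datetime import datetime

df = pl.DataFrame(
    {
        "dt": [datetime(2022, 1, 1, 0, m) for m in range(1, 8)],
        "key": ["A", "A", "B", "B", "A", "B", "A"],
        "val": [1, 1, 1, 1, 1, 1, 1],
    }
)

out = df.groupby_rolling(index_column="dt", by="key", period="2m").agg(
    [pl.col("val").sum().alias("sum val")]
)

# After this commit, "dt" comes back grouped per key, aligned with "key":
# key: A, A, A, A, B, B, B
# dt:  00:01, 00:02, 00:05, 00:07, 00:03, 00:04, 00:06
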
20 changes: 13 additions & 7 deletions polars/polars-time/src/groupby/dynamic.rs
@@ -332,10 +332,10 @@ impl Wrap<&DataFrame> {
         tu: TimeUnit,
         time_type: &DataType,
     ) -> Result<(Series, Vec<Series>, GroupsProxy)> {
-        let dt = dt.rechunk();
-        let dt = dt.datetime().unwrap();
+        let mut dt = dt.rechunk();
 
         let groups = if by.is_empty() {
+            let dt = dt.datetime().unwrap();
             let vals = dt.downcast_iter().next().unwrap();
             let ts = vals.values().as_slice();
             GroupsProxy::Slice {
@@ -353,13 +353,21 @@ impl Wrap<&DataFrame> {
                 .0
                 .groupby_with_series(by.clone(), true, true)?
                 .take_groups();
+
+            // we keep a local copy, as we are reordering on next operation.
+            let dt_local = dt.datetime().unwrap().clone();
+
+            // make sure that the output order is correct
+            dt = unsafe { dt.agg_list(&groups).explode().unwrap() };
+
+            // continue determining the rolling indexes.
             let groups = groups.into_idx();
 
             let groupsidx = POOL.install(|| {
                 groups
                     .par_iter()
                     .flat_map(|base_g| {
-                        let dt = unsafe { dt.take_unchecked(base_g.1.into()) };
+                        let dt = unsafe { dt_local.take_unchecked(base_g.1.into()) };
                         let vals = dt.downcast_iter().next().unwrap();
                         let ts = vals.values().as_slice();
                         let sub_groups = groupby_values(
@@ -374,13 +382,11 @@ impl Wrap<&DataFrame> {
                     })
                     .collect()
             });
-            let mut groups = GroupsProxy::Idx(groupsidx);
-            groups.sort();
-            groups
+            GroupsProxy::Idx(groupsidx)
         };
         let dt = dt.cast(time_type).unwrap();
 
-        // the ordering has changed due to the groupby
+        // // the ordering has changed due to the groupby
         if !by.is_empty() {
             unsafe {
                 for key in by.iter_mut() {
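
The core of the Rust change: the pre-reorder index column is kept in dt_local for slicing out each group's timestamps, while dt itself is rebuilt with agg_list(&groups).explode() so the returned index column comes out in group order, and the group indexes are no longer re-sorted afterwards. A rough Python analogy of that reorder step, using the public API rather than the internals (maintain_order=True and the column names are assumptions for illustration):

import polars as pl
from datetime import datetime

df = pl.DataFrame(
    {
        "dt": [datetime(2022, 1, 1, 0, m) for m in range(1, 8)],
        "key": ["A", "A", "B", "B", "A", "B", "A"],
    }
)

# Collect "dt" per group, then explode: the same group-major reordering
# that dt.agg_list(&groups).explode() performs on the Rust side.
reordered = (
    df.groupby("key", maintain_order=True)
    .agg([pl.col("dt")])
    .explode("dt")
)
# key: A, A, A, A, B, B, B
# dt:  00:01, 00:02, 00:05, 00:07, 00:03, 00:04, 00:06
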
16 changes: 9 additions & 7 deletions py-polars/tests/test_datelike.py
@@ -1065,17 +1065,17 @@ def test_groupby_rolling_by_ordering() -> None:
     ).to_dict(
         False
     ) == {
-        "key": ["A", "A", "B", "B", "A", "B", "A"],
+        "key": ["A", "A", "A", "A", "B", "B", "B"],
         "dt": [
             datetime(2022, 1, 1, 0, 1),
             datetime(2022, 1, 1, 0, 2),
+            datetime(2022, 1, 1, 0, 5),
+            datetime(2022, 1, 1, 0, 7),
             datetime(2022, 1, 1, 0, 3),
             datetime(2022, 1, 1, 0, 4),
-            datetime(2022, 1, 1, 0, 5),
             datetime(2022, 1, 1, 0, 6),
-            datetime(2022, 1, 1, 0, 7),
         ],
-        "sum val": [2, 2, 2, 2, 1, 1, 1],
+        "sum val": [2, 2, 1, 1, 2, 2, 1],
     }


@@ -1118,16 +1118,18 @@ def test_groupby_rolling_by_() -> None:
         ),
         how="cross",
     )
-    out = df.groupby_rolling(index_column="datetime", by="group", period="3d").agg(
-        [pl.count().alias("count")]
+    out = (
+        df.sort("datetime")
+        .groupby_rolling(index_column="datetime", by="group", period="3d")
+        .agg([pl.count().alias("count")])
     )
 
     expected = (
         df.sort(["group", "datetime"])
         .groupby_rolling(index_column="datetime", by="group", period="3d")
         .agg([pl.count().alias("count")])
     )
-    assert out.frame_equal(expected)
+    assert out.sort(["group", "datetime"]).frame_equal(expected)
     assert out.to_dict(False) == {
         "group": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
         "datetime": [
