Skip to content

Commit

Permalink
fix[rust]: overlapping groups should not explode date in groupby (#4637)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 31, 2022
1 parent 85aa94b commit 7c6820f
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 17 deletions.
11 changes: 11 additions & 0 deletions polars/polars-core/src/chunked_array/iterator/par/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,15 @@ impl ListChunked {
})
.flatten()
}

// Get an indexed parallel iterator over the [`Series`] in this [`ListChunked`].
pub fn par_iter_indexed(&mut self) -> impl IndexedParallelIterator<Item = Option<Series>> + '_ {
*self = self.rechunk();
let arr = self.downcast_iter().next().unwrap();

let dtype = self.inner_dtype();
(0..arr.len())
.into_par_iter()
.map(move |idx| unsafe { idx_to_array(idx, arr, &dtype) })
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ impl Executor for GroupByDynamicExec {
let mut df = self.input.execute(state)?;
df.as_single_chunk_par();
state.set_schema(self.input_schema.clone());

// if the periods are larger than the intervals,
// the groups overlap
if self.options.every < self.options.period {
state.flags |= StateFlags::OVERLAPPING_GROUPS
}

let keys = self
.keys
.iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ impl Executor for GroupByRollingExec {
}
};

// a rolling groupby has overlapping windows
state.flags |= StateFlags::OVERLAPPING_GROUPS;

let agg_columns = POOL.install(|| {
self.aggs
.par_iter()
Expand Down
93 changes: 81 additions & 12 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use polars_core::frame::groupby::GroupsProxy;
use polars_core::series::unstable::UnstableSeries;
use polars_core::{prelude::*, POOL};
use rayon::prelude::*;

use crate::physical_plan::state::{ExecutionState, StateFlags};
use crate::prelude::*;
Expand Down Expand Up @@ -133,7 +134,12 @@ impl PhysicalExpr for BinaryExpr {
));
}

match (ac_l.agg_state(), ac_r.agg_state(), self.op) {
match (
ac_l.agg_state(),
ac_r.agg_state(),
self.op,
state.overlapping_groups(),
) {
// Some aggregations must return boolean masks that fit the group. That's why not all literals can take this path.
// only literals that are used in arithmetic
(
Expand All @@ -145,6 +151,7 @@ impl PhysicalExpr for BinaryExpr {
| Operator::Multiply
| Operator::Modulus
| Operator::TrueDivide,
_,
)
| (
AggState::Literal(lhs),
Expand All @@ -155,6 +162,7 @@ impl PhysicalExpr for BinaryExpr {
| Operator::Multiply
| Operator::Modulus
| Operator::TrueDivide,
_,
) => {
// we want to be able to mutate in place
// so we take the lhs to make sure that we drop
Expand All @@ -175,8 +183,8 @@ impl PhysicalExpr for BinaryExpr {
// the other is a literal value. In that case it is unlikely we want to expand this to the
// group sizes.
//
(AggState::AggregatedFlat(_), AggState::Literal(_), _)
| (AggState::Literal(_), AggState::AggregatedFlat(_), _) => {
(AggState::AggregatedFlat(_), AggState::Literal(_), _op, _overlapping_groups)
| (AggState::Literal(_), AggState::AggregatedFlat(_), _op, _overlapping_groups) => {
let l = ac_l.series().clone();
let r = ac_r.series().clone();

Expand All @@ -194,7 +202,7 @@ impl PhysicalExpr for BinaryExpr {
// if the groups_len == df.len we can just apply all flat.
// within an aggregation a `col().first() - lit(0)` must still produce a boolean array of group length,
// that's why a literal also takes this branch
(AggState::AggregatedFlat(s), AggState::NotAggregated(_), _)
(AggState::AggregatedFlat(s), AggState::NotAggregated(_), _op, _overlapping_groups)
if s.len() != df.height() =>
{
// this is a flat series of len eq to group tuples
Expand Down Expand Up @@ -245,7 +253,8 @@ impl PhysicalExpr for BinaryExpr {
(
AggState::AggregatedList(_) | AggState::NotAggregated(_),
AggState::AggregatedFlat(s),
_,
_op,
_overlapping_groups,
) if s.len() != df.height() => {
// this is now a list
let l = ac_l.aggregated_arity_operation();
Expand Down Expand Up @@ -297,9 +306,27 @@ impl PhysicalExpr for BinaryExpr {
}
Ok(ac_l)
}
(AggState::AggregatedList(_), AggState::NotAggregated(_) | AggState::Literal(_), _)
| (AggState::NotAggregated(_) | AggState::Literal(_), AggState::AggregatedList(_), _) =>
{

// # Align data in sort order and apply flattened.
// 1 we sort/aggregate by groups
// 2 then we flatten/explode and do the binary operation.
// 3 then we use the original groups length to restore the groups
//
// Overlapping groups may not take this branch.
// when groups overlap, step 2 creates more values than rows
// and the original group lengths will be incorrect
(
AggState::AggregatedList(_),
AggState::NotAggregated(_) | AggState::Literal(_),
_op,
false,
)
| (
AggState::NotAggregated(_) | AggState::Literal(_),
AggState::AggregatedList(_),
_op,
false,
) => {
ac_l.sort_by_groups();
ac_r.sort_by_groups();

Expand All @@ -318,8 +345,11 @@ impl PhysicalExpr for BinaryExpr {
ac_l.with_series(out, false);
Ok(ac_l)
}
// flatten the Series and apply the operators
(AggState::AggregatedList(_), AggState::AggregatedList(_), _) => {
// # Flatten the Series and apply the operators.
//
// Overlapping groups may not take this branch.
// the explode call would create more data and is expensive
(AggState::AggregatedList(_), AggState::AggregatedList(_), _op, false) => {
let lhs = ac_l.flat_naive().as_ref().clone();
let rhs = ac_r.flat_naive().as_ref().clone();

Expand All @@ -334,9 +364,9 @@ impl PhysicalExpr for BinaryExpr {
ac_l.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac_l)
}
// Both are or a flat series
// Both are or a flat series (if groups do not overlap)
// so we can flatten the Series and apply the operators
_ => {
(_l, _r, _op, false) => {
// Check if the group state of `ac_a` differs from the original `GroupTuples`.
// If this is the case we might need to align the groups. But only if `ac_b` is not a
// `Literal` as literals don't have any groups, the changed group order does not matter
Expand Down Expand Up @@ -380,6 +410,45 @@ impl PhysicalExpr for BinaryExpr {
Ok(ac_l)
}
}
// overlapping groups, we iterate the separate groups, so that we don't have to explode
// If both sides are aggregated to a list, we can apply in parallel
(AggState::AggregatedList(_), AggState::AggregatedList(_), _op, true) => {
let l = ac_l.aggregated();
let r = ac_r.aggregated();

let mut l = l.list()?.clone();
let mut r = r.list()?.clone();

let mut out = POOL.install(|| {
l.par_iter_indexed()
.zip(r.par_iter_indexed())
.map(|(opt_l, opt_r)| match (opt_l, opt_r) {
(Some(l), Some(r)) => apply_operator(&l, &r, self.op).map(Some),
_ => Ok(None),
})
.collect::<Result<ListChunked>>()
})?;

out.rename(ac_l.series().name());
ac_l.with_series(out.into_series(), true);
Ok(ac_l)
}
// overlapping groups, we iterate the separate groups, so that we don't have to explode
(_l, _r, _op, true) => {
let mut out = ac_l
.iter_groups()
.zip(ac_r.iter_groups())
.map(|(opt_l, opt_r)| match (opt_l, opt_r) {
(Some(l), Some(r)) => {
apply_operator(l.as_ref(), r.as_ref(), self.op).map(Some)
}
_ => Ok(None),
})
.collect::<Result<ListChunked>>()?;
out.rename(ac_l.series().name());
ac_l.with_series(out.into_series(), true);
Ok(ac_l)
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ impl<'a> AggregationContext<'a> {
})
}
// sliced groups are already in correct order
GroupsProxy::Slice { .. } => {}
GroupsProxy::Slice { groups, .. } => {
dbg!(groups);
dbg!("hree");
}
}
self.update_groups = UpdateGroups::No;
}
Expand Down
5 changes: 2 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,8 @@ impl WindowExpr {
if let GroupsProxy::Idx(g) = gb.get_groups() {
debug_assert!(g.is_sorted())
}
else {
debug_assert!(false)
}
// GroupsProxy::Slice is always sorted

// Note that group columns must be sorted for this to make sense!!!
Ok(MapStrategy::Explode)
} else {
Expand Down
16 changes: 15 additions & 1 deletion polars/polars-lazy/src/physical_plan/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ pub type GroupsProxyCache = Arc<Mutex<PlHashMap<String, GroupsProxy>>>;

bitflags! {
pub(super) struct StateFlags: u8 {
/// More verbose logging
const VERBOSE = 0x01;
/// Indicates that window expression's [`GroupTuples`] may be cached.
const CACHE_WINDOW_EXPR = 0x02;
const FILTER_NODE = 0x03;
/// Indicates that a groupby operations groups may overlap.
/// If this is the case, an `explode` will yield more values than rows in original `df`,
/// this breaks some assumptions
const OVERLAPPING_GROUPS = 0x03;
}
}

Expand Down Expand Up @@ -152,10 +157,19 @@ impl ExecutionState {
lock.clear();
}

/// Indicates that window expression's [`GroupTuples`] may be cached.
pub(super) fn cache_window(&self) -> bool {
self.flags.contains(StateFlags::CACHE_WINDOW_EXPR)
}

/// Indicates that a groupby operations groups may overlap.
/// If this is the case, an `explode` will yield more values than rows in original `df`,
/// this breaks some assumptions
pub(super) fn overlapping_groups(&self) -> bool {
self.flags.contains(StateFlags::OVERLAPPING_GROUPS)
}

/// More verbose logging
pub(super) fn verbose(&self) -> bool {
self.flags.contains(StateFlags::VERBOSE)
}
Expand Down
13 changes: 13 additions & 0 deletions polars/polars-time/src/windows/duration.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::Ordering;
use std::ops::Mul;

use chrono::{Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
Expand Down Expand Up @@ -30,6 +31,18 @@ pub struct Duration {
pub parsed_int: bool,
}

impl PartialOrd<Self> for Duration {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.duration_ns().partial_cmp(&other.duration_ns())
}
}

impl Ord for Duration {
fn cmp(&self, other: &Self) -> Ordering {
self.duration_ns().cmp(&other.duration_ns())
}
}

impl Duration {
/// Create a new integer size `Duration`
pub fn new(fixed_slots: i64) -> Self {
Expand Down
35 changes: 35 additions & 0 deletions py-polars/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,38 @@ def test_groupby_dynamic_slice_pushdown() -> None:
"a": [0, 2],
"c": [1, 3],
}


def test_overlapping_groups_4628() -> None:
df = pl.DataFrame(
{
"index": [1, 2, 3, 4, 5, 6],
"val": [10, 20, 40, 70, 110, 160],
}
)
assert (
df.groupby_rolling(index_column="index", period="3i",).agg(
[
pl.col("val").diff(n=1).alias("val.diff"),
(pl.col("val") - pl.col("val").shift(1)).alias("val - val.shift"),
]
)
).to_dict(False) == {
"index": [1, 2, 3, 4, 5, 6],
"val.diff": [
[None],
[None, 10],
[None, 10, 20],
[None, 20, 30],
[None, 30, 40],
[None, 40, 50],
],
"val - val.shift": [
[None],
[None, 10],
[None, 10, 20],
[None, 20, 30],
[None, 30, 40],
[None, 40, 50],
],
}

0 comments on commit 7c6820f

Please sign in to comment.