Skip to content

Commit

Permalink
fix[rust]: fix groupby_dynamic flat agg type (#4822)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 11, 2022
1 parent 10627b8 commit abda9fb
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 49 deletions.
24 changes: 23 additions & 1 deletion polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,13 @@ impl LazyFrame {
}
}

/// Create rolling groups based on a time column.
///
/// Also works for index values of type Int32 or Int64.
///
/// Different from [`LazyFrame::groupby_dynamic`], the windows are now determined by the
/// individual values and are not of constant intervals. For constant intervals use
/// *groupby_dynamic*
pub fn groupby_rolling<E: AsRef<[Expr]>>(
self,
by: E,
Expand All @@ -853,6 +860,21 @@ impl LazyFrame {
}
}

/// Group based on a time value (or index value of type Int32, Int64).
///
/// Time windows are calculated and rows are assigned to windows. Unlike a
/// normal groupby, a row can be a member of multiple groups. The time/index
/// window could be seen as a rolling window, with a window size determined by
/// dates/times/values instead of slots in the DataFrame.
///
/// A window is defined by:
///
/// - every: interval of the window
/// - period: length of the window
/// - offset: offset of the window
///
/// The `by` argument should be empty `[]` if you don't want to combine this
/// with an ordinary groupby on these keys.
pub fn groupby_dynamic<E: AsRef<[Expr]>>(
self,
by: E,
Expand All @@ -869,7 +891,7 @@ impl LazyFrame {
}
}

/// Similar to groupby, but order of the DataFrame is maintained.
/// Similar to [`groupby`], but order of the DataFrame is maintained.
pub fn groupby_stable<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(self, by: E) -> LazyGroupBy {
let keys = by
.as_ref()
Expand Down
53 changes: 7 additions & 46 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,31 +138,13 @@ impl PhysicalExpr for BinaryExpr {
match (
ac_l.agg_state(),
ac_r.agg_state(),
self.op,
state.overlapping_groups(),
) {
// Some aggregations must return boolean masks that fit the group. That's why not all literals can take this path.
// only literals that are used in arithmetic
(
AggState::AggregatedFlat(lhs),
AggState::Literal(rhs),
Operator::Plus
| Operator::Minus
| Operator::Divide
| Operator::Multiply
| Operator::Modulus
| Operator::TrueDivide,
_,
)
| (
AggState::Literal(lhs),
AggState::AggregatedFlat(rhs),
Operator::Plus
| Operator::Minus
| Operator::Divide
| Operator::Multiply
| Operator::Modulus
| Operator::TrueDivide,
AggState::AggregatedFlat(lhs) | AggState::Literal(lhs),
AggState::AggregatedFlat(rhs) | AggState::Literal(rhs),
_,
) => {
// we want to be able to mutate in place
Expand All @@ -181,29 +163,11 @@ impl PhysicalExpr for BinaryExpr {
Ok(ac_l)
}
// One of the two exprs is aggregated with flat aggregation, e.g. `e.min(), e.max(), e.first()`
// the other is a literal value. In that case it is unlikely we want to expand this to the
// group sizes.
//
(AggState::AggregatedFlat(_), AggState::Literal(_), _op, _overlapping_groups)
| (AggState::Literal(_), AggState::AggregatedFlat(_), _op, _overlapping_groups) => {
let l = ac_l.series().clone();
let r = ac_r.series().clone();

// drop lhs so that we might operate in place
{
let _ = ac_l.take();
}
let out = apply_operator_owned(l, r, self.op)?;

ac_l.with_series(out, true);
Ok(ac_l)
}
// One of the two exprs is aggregated with flat aggregation, e.g. `e.min(), e.max(), e.first()`

// if the groups_len == df.len we can just apply all flat.
// within an aggregation a `col().first() - lit(0)` must still produce a boolean array of group length,
// that's why a literal also takes this branch
(AggState::AggregatedFlat(s), AggState::NotAggregated(_), _op, _overlapping_groups)
(AggState::AggregatedFlat(s), AggState::NotAggregated(_), _overlapping_groups)
if s.len() != df.height() =>
{
// this is a flat series of len eq to group tuples
Expand Down Expand Up @@ -254,7 +218,6 @@ impl PhysicalExpr for BinaryExpr {
(
AggState::AggregatedList(_) | AggState::NotAggregated(_),
AggState::AggregatedFlat(s),
_op,
_overlapping_groups,
) if s.len() != df.height() => {
// this is now a list
Expand Down Expand Up @@ -319,13 +282,11 @@ impl PhysicalExpr for BinaryExpr {
(
AggState::AggregatedList(_),
AggState::NotAggregated(_) | AggState::Literal(_),
_op,
false,
)
| (
AggState::NotAggregated(_) | AggState::Literal(_),
AggState::AggregatedList(_),
_op,
false,
) => {
ac_l.sort_by_groups();
Expand All @@ -350,7 +311,7 @@ impl PhysicalExpr for BinaryExpr {
//
// Overlapping groups may not take this branch.
// the explode call would create more data and is expensive
(AggState::AggregatedList(_), AggState::AggregatedList(_), _op, false) => {
(AggState::AggregatedList(_), AggState::AggregatedList(_), false) => {
let lhs = ac_l.flat_naive().as_ref().clone();
let rhs = ac_r.flat_naive().as_ref().clone();

Expand All @@ -367,7 +328,7 @@ impl PhysicalExpr for BinaryExpr {
}
// Both are a flat series (if groups do not overlap)
// so we can flatten the Series and apply the operators
(_l, _r, _op, false) => {
(_l, _r, false) => {
// Check if the group state of `ac_a` differs from the original `GroupTuples`.
// If this is the case we might need to align the groups. But only if `ac_b` is not a
// `Literal` as literals don't have any groups, the changed group order does not matter
Expand Down Expand Up @@ -413,7 +374,7 @@ impl PhysicalExpr for BinaryExpr {
}
// overlapping groups, we iterate the separate groups, so that we don't have to explode
// If both sides are aggregated to a list, we can apply in parallel
(AggState::AggregatedList(_), AggState::AggregatedList(_), _op, true) => {
(AggState::AggregatedList(_), AggState::AggregatedList(_), true) => {
let l = ac_l.aggregated();
let r = ac_r.aggregated();

Expand All @@ -435,7 +396,7 @@ impl PhysicalExpr for BinaryExpr {
Ok(ac_l)
}
// overlapping groups, we iterate the separate groups, so that we don't have to explode
(_l, _r, _op, true) => {
(_l, _r, true) => {
let mut out = ac_l
.iter_groups()
.zip(ac_r.iter_groups())
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3105,7 +3105,7 @@ def groupby_rolling(
Also works for index values of type Int32 or Int64.
Different from a rolling groupby the windows are now determined by the
Different from ``groupby_dynamic``, the windows are now determined by the
individual values and are not of constant intervals. For constant intervals use
*groupby_dynamic*
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ def groupby_rolling(
Also works for index values of type Int32 or Int64.
Different from a rolling groupby the windows are now determined by the
Different from ``groupby_dynamic``, the windows are now determined by the
individual values and are not of constant intervals. For constant intervals use
*groupby_dynamic*
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,20 @@ def test_unique_order() -> None:
"row_nr": [0, 1],
"a": [1, 2],
}


def test_groupby_dynamic_flat_agg_4814() -> None:
    """Check binary expressions over flat aggregations in ``groupby_dynamic``.

    With ``every="1i"`` and ``period="2i"`` the expected output has two
    groups (keys 1 and 2). Dividing two flat aggregations (``sum``/``last``)
    must give one value per group, and must agree with aggregating the
    element-wise ratio afterwards.
    """
    frame = pl.DataFrame({"a": [1, 2, 2], "b": [1, 8, 12]})
    grouped = frame.groupby_dynamic("a", every="1i", period="2i")

    # ratio of two flat aggregations vs. flat aggregation of a ratio
    sum_ratio = (pl.col("b").sum() / pl.col("a").sum()).alias("sum_ratio_1")
    last_ratio_of_aggs = (pl.col("b").last() / pl.col("a").last()).alias(
        "last_ratio_1"
    )
    last_of_ratio = (pl.col("b") / pl.col("a")).last().alias("last_ratio_2")

    result = grouped.agg([sum_ratio, last_ratio_of_aggs, last_of_ratio])

    expected = {
        "a": [1, 2],
        "sum_ratio_1": [4.2, 5.0],
        "last_ratio_1": [6.0, 6.0],
        "last_ratio_2": [6.0, 6.0],
    }
    assert result.to_dict(False) == expected

0 comments on commit abda9fb

Please sign in to comment.