Skip to content

Commit

Permalink
fix(rust, python): don't block non matching groups in binary expressi…
Browse files Browse the repository at this point in the history
…on (#5273)
  • Loading branch information
ritchie46 committed Oct 20, 2022
1 parent 28b55f3 commit 88c850f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 33 deletions.
6 changes: 0 additions & 6 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use polars_core::series::unstable::UnstableSeries;
use polars_core::POOL;
use rayon::prelude::*;

use crate::physical_plan::expression_err;
use crate::physical_plan::state::{ExecutionState, StateFlags};
use crate::prelude::*;

Expand Down Expand Up @@ -124,11 +123,6 @@ impl PhysicalExpr for BinaryExpr {
let mut ac_l = result_a?;
let mut ac_r = result_b?;

if !ac_l.can_combine(&ac_r) {
let msg = "Cannot combine this binary expression, the groups do not match.";
return Err(expression_err!(msg, self.expr, InvalidOperation));
}

match (
ac_l.agg_state(),
ac_r.agg_state(),
Expand Down
27 changes: 0 additions & 27 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,29 +157,6 @@ impl<'a> AggregationContext<'a> {
&self.groups
}

/// Check if this contexts group tuples can be combined with that of other.
pub(crate) fn can_combine(&self, other: &AggregationContext) -> bool {
match (
&self.groups,
self.sorted,
self.is_original_len(),
&other.groups,
other.sorted,
other.original_len,
) {
(Cow::Borrowed(_), _, _, Cow::Borrowed(_), _, _) => true,
(Cow::Owned(_), _, _, Cow::Borrowed(_), _, _) => true,
(Cow::Borrowed(_), _, _, Cow::Owned(_), _, _) => true,
(Cow::Owned(_), true, true, Cow::Owned(_), true, true) => true,
(Cow::Owned(_), true, false, Cow::Owned(_), true, true) => false,
(Cow::Owned(_), true, true, Cow::Owned(_), true, false) => false,
(Cow::Owned(_), true, _, Cow::Owned(_), true, _) => {
self.groups.len() == other.groups.len()
}
_ => false,
}
}

pub(crate) fn series(&self) -> &Series {
match &self.state {
AggState::NotAggregated(s)
Expand Down Expand Up @@ -253,10 +230,6 @@ impl<'a> AggregationContext<'a> {
}
}

pub(crate) fn is_original_len(&self) -> bool {
self.original_len
}

pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
self.original_len = original_len;
self
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,10 @@ def test_groupby_dynamic_overlapping_groups_flat_apply_multiple_5038() -> None:
).collect().sum().to_dict(False) == pytest.approx(
{"a": [None], "corr": [6.988674024215477]}
)


def test_take_in_groupby() -> None:
df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "values": [10, 200, 3, 40, 500, 6]})
assert df.groupby("group").agg(
pl.col("values").take(1) - pl.col("values").take(2)
).sort("group").to_dict(False) == {"group": [1, 2], "values": [197, 494]}

0 comments on commit 88c850f

Please sign in to comment.