Skip to content

Commit

Permalink
fix bug in binary aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 1, 2022
1 parent 13cb614 commit 61146e8
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 38 deletions.
23 changes: 14 additions & 9 deletions polars/polars-arrow/src/kernels/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,17 @@ pub unsafe fn take_primitive_unchecked<T: NativeType>(
arr: &PrimitiveArray<T>,
indices: &UInt32Array,
) -> Arc<PrimitiveArray<T>> {
let array_values = arr.values();
let index_values = indices.values();
let array_values = arr.values().as_slice();
let index_values = indices.values().as_slice();
let validity_values = arr.validity().expect("should have nulls");

// first take the values, these are always needed
let values: Vec<T> = index_values
.iter()
.map(|idx| *array_values.get_unchecked(*idx as usize))
.map(|idx| {
debug_assert!((*idx as usize) < array_values.len());
*array_values.get_unchecked(*idx as usize)
})
.collect_trusted();

// the validity buffer we will fill with all valid. And we unset the ones that are null
Expand Down Expand Up @@ -92,9 +95,10 @@ pub unsafe fn take_no_null_primitive<T: NativeType>(
let array_values = arr.values().as_slice();
let index_values = indices.values().as_slice();

let iter = index_values
.iter()
.map(|idx| *array_values.get_unchecked(*idx as usize));
let iter = index_values.iter().map(|idx| {
debug_assert!((*idx as usize) < array_values.len());
*array_values.get_unchecked(*idx as usize)
});

let values = Buffer::from_trusted_len_iter(iter);
let validity = indices.validity().cloned();
Expand All @@ -121,9 +125,10 @@ pub unsafe fn take_no_null_primitive_iter_unchecked<
debug_assert!(!arr.has_validity());
let array_values = arr.values().as_slice();

let iter = indices
.into_iter()
.map(|idx| *array_values.get_unchecked(idx));
let iter = indices.into_iter().map(|idx| {
debug_assert!((idx as usize) < array_values.len());
*array_values.get_unchecked(idx)
});

let values = Buffer::from_trusted_len_iter_unchecked(iter);
Arc::new(PrimitiveArray::from_data(T::PRIMITIVE.into(), values, None))
Expand Down
7 changes: 4 additions & 3 deletions polars/polars-core/src/frame/groupby/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -778,9 +778,10 @@ where
// Safety:
// group tuples are in bounds
unsafe {
list_values.extend(
idx.iter().map(|idx| *values.get_unchecked(*idx as usize)),
);
list_values.extend(idx.iter().map(|idx| {
debug_assert!((*idx as usize) < values.len());
*values.get_unchecked(*idx as usize)
}));
// Safety:
// we know that offsets has allocated enough slots
offsets.push_unchecked(length_so_far);
Expand Down
18 changes: 18 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,23 @@ impl PhysicalExpr for BinaryExpr {
ac_l.with_series(ca.into_series(), true);
Ok(ac_l)
}
(AggState::AggregatedList(_), AggState::NotAggregated(_))
| (AggState::NotAggregated(_), AggState::AggregatedList(_)) => {
ac_l.sort_by_groups();
ac_r.sort_by_groups();

let out = apply_operator(
ac_l.flat_naive().as_ref(),
ac_r.flat_naive().as_ref(),
self.op,
)?;

// we flattened the series, so that sorts by group
ac_l.with_update_groups(UpdateGroups::WithGroupsLen);
ac_l.with_series(out, false);
Ok(ac_l)
}

// Both are or a flat series or aggregated into a list
// so we can flatten the Series and apply the operators
_ => {
Expand All @@ -197,6 +214,7 @@ impl PhysicalExpr for BinaryExpr {
ac_r.flat_naive().as_ref(),
self.op,
)?;

ac_l.combine_groups(ac_r).with_series(out, false);
Ok(ac_l)
}
Expand Down
37 changes: 37 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ pub(crate) enum AggState {
None,
}

impl AggState {
// Literal series are not safe to aggregate
fn safe_to_agg(&self, groups: &GroupsProxy) -> bool {
match self {
AggState::NotAggregated(s) => {
!(s.len() == 1
// or more then one group
&& (groups.len() > 1
// or single groups with more than one index
|| !groups.is_empty()
&& groups.get(0).len() > 1))
}
_ => true,
}
}
}

// lazy update strategy
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) enum UpdateGroups {
Expand Down Expand Up @@ -264,6 +281,26 @@ impl<'a> AggregationContext<'a> {
self
}

/// In a binary expression one state can be aggregated and the other not.
/// If both would be flattened naively one would be sorted and the other not.
/// Calling this function will ensure both are sortened. This will be a no-op
/// if already aggregated.
pub(crate) fn sort_by_groups(&mut self) {
match &self.series {
AggState::NotAggregated(s) => {
// We should not aggregate literals!!
if self.series.safe_to_agg(&self.groups) {
let agg = s.agg_list(&self.groups).unwrap();
self.update_groups = UpdateGroups::WithGroupsLen;
self.series = AggState::AggregatedList(agg);
}
}
AggState::AggregatedFlat(_) => {}
AggState::AggregatedList(_) => {}
AggState::None => {}
}
}

/// # Arguments
/// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
/// the columns dtype)
Expand Down
26 changes: 26 additions & 0 deletions polars/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,29 @@ fn test_binary_agg_context_2() -> Result<()> {

Ok(())
}

#[test]
fn test_shift_elementwise_issue_2509() -> Result<()> {
let df = df![
"x"=> [0, 0, 0, 1, 1, 1, 2, 2, 2],
"y"=> [0, 10, 20, 0, 10, 20, 0, 10, 20]
]?;
let out = df
.lazy()
// Don't use maintain order here! That hides the bug
.groupby([col("x")])
.agg(&[(col("y").shift(-1) + col("x")).list().alias("sum")])
.sort("x", false)
.collect()?;

let out = out.explode(["sum"])?;
let out = out.column("sum")?;
assert_eq!(out.get(0), AnyValue::Int32(10));
assert_eq!(out.get(1), AnyValue::Int32(20));
assert_eq!(out.get(2), AnyValue::Null);
assert_eq!(out.get(3), AnyValue::Int32(11));
assert_eq!(out.get(4), AnyValue::Int32(21));
assert_eq!(out.get(5), AnyValue::Null);

Ok(())
}
15 changes: 15 additions & 0 deletions polars/polars-lazy/src/tests/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,18 @@ fn test_single_thread_when_then_otherwise_categorical() -> Result<()> {
assert!(s.contains("same"));
Ok(())
}

#[test]
fn test_lazy_ternary() {
let df = get_df()
.lazy()
.with_column(
when(col("sepal.length").lt(lit(5.0)))
.then(lit(10))
.otherwise(lit(1))
.alias("new"),
)
.collect()
.unwrap();
assert_eq!(Some(43), df.column("new").unwrap().sum::<i32>());
}
8 changes: 8 additions & 0 deletions polars/polars-lazy/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ mod predicate_queries;
mod projection_queries;
mod queries;

fn load_df() -> DataFrame {
df!("a" => &[1, 2, 3, 4, 5],
"b" => &["a", "a", "b", "c", "c"],
"c" => &[1, 2, 3, 4, 5]
)
.unwrap()
}

use optimization_checks::*;
use std::sync::Mutex;

Expand Down
9 changes: 7 additions & 2 deletions polars/polars-lazy/src/tests/optimization_checks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::*;
use crate::tests::queries::load_df;

fn get_arenas() -> (Arena<AExpr>, Arena<ALogicalPlan>) {
let expr_arena = Arena::with_capacity(16);
Expand Down Expand Up @@ -98,7 +97,13 @@ fn test_no_left_join_pass() -> Result<()> {
.filter(col("bar").eq(lit(5i32)))
.collect()?;

dbg!(out);
let expected = df![
"foo" => ["abc", "def"],
"idx1" => [0, 0],
"bar" => [5, 5],
]?;

assert!(out.frame_equal(&expected));
Ok(())
}

Expand Down
24 changes: 0 additions & 24 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,6 @@ use super::*;
use polars_arrow::prelude::QuantileInterpolOptions;
use polars_core::series::ops::NullBehavior;

pub(crate) fn load_df() -> DataFrame {
df!("a" => &[1, 2, 3, 4, 5],
"b" => &["a", "a", "b", "c", "c"],
"c" => &[1, 2, 3, 4, 5]
)
.unwrap()
}

#[test]
fn test_lazy_ternary() {
let df = get_df()
.lazy()
.with_column(
when(col("sepal.length").lt(lit(5.0)))
.then(lit(10))
.otherwise(lit(1))
.alias("new"),
)
.collect()
.unwrap();
assert_eq!(Some(43), df.column("new").unwrap().sum::<i32>());
}

#[test]
fn test_lazy_with_column() {
let df = get_df()
Expand Down Expand Up @@ -91,7 +68,6 @@ fn test_lazy_melt() {
.collect()
.unwrap();
assert_eq!(out.shape(), (7, 3));
dbg!(out);
}

#[test]
Expand Down

0 comments on commit 61146e8

Please sign in to comment.