Skip to content

Commit

Permalink
fix null aggregation edge case (#3536)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 30, 2022
1 parent ada3a0e commit 77ae80a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
10 changes: 7 additions & 3 deletions polars/polars-arrow/src/kernels/take_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub unsafe fn take_agg_no_null_primitive_iter_unchecked<

/// Take kernel for single chunk and an iterator as index.
/// # Safety
/// caller must enure iterators indexes are in bounds
/// caller must ensure iterators indexes are in bounds
#[inline]
pub unsafe fn take_agg_primitive_iter_unchecked<
T: NativeType,
Expand All @@ -44,18 +44,21 @@ pub unsafe fn take_agg_primitive_iter_unchecked<
indices: I,
f: F,
init: T,
len: IdxSize,
) -> Option<T> {
let array_values = arr.values().as_slice();
let validity = arr.validity().expect("null buffer should be there");
let mut null_count = 0 as IdxSize;

let out = indices.into_iter().fold(init, |acc, idx| {
if validity.get_bit_unchecked(idx) {
f(acc, *array_values.get_unchecked(idx))
} else {
null_count += 1;
acc
}
});
if out == init {
if null_count == len {
None
} else {
Some(out)
Expand All @@ -76,6 +79,7 @@ pub unsafe fn take_agg_primitive_iter_unchecked_count_nulls<
indices: I,
f: F,
init: TOut,
len: IdxSize,
) -> Option<(TOut, IdxSize)> {
let array_values = arr.values().as_slice();
let validity = arr.validity().expect("null buffer should be there");
Expand All @@ -92,7 +96,7 @@ pub unsafe fn take_agg_primitive_iter_unchecked_count_nulls<
acc
}
});
if out == init {
if null_count == len {
None
} else {
Some((out, null_count))
Expand Down
13 changes: 11 additions & 2 deletions polars/polars-core/src/frame/groupby/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ where
idx.iter().map(|i| *i as usize),
|a, b| if a < b { a } else { b },
T::Native::max_value(),
idx.len() as IdxSize,
)
},
_ => {
Expand Down Expand Up @@ -469,6 +470,7 @@ where
idx.iter().map(|i| *i as usize),
|a, b| if a > b { a } else { b },
T::Native::min_value(),
idx.len() as IdxSize,
)
},
_ => {
Expand Down Expand Up @@ -547,12 +549,17 @@ where
)
}),
(_, 1) => unsafe {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
let out = take_agg_primitive_iter_unchecked::<T::Native, _, _>(
self.downcast_iter().next().unwrap(),
idx.iter().map(|i| *i as usize),
|a, b| a + b,
T::Native::zero(),
)
idx.len() as IdxSize,
);
if out.is_none() {
dbg!(idx, self);
};
out
},
_ => {
let take = unsafe { self.take_unchecked(idx.into()) };
Expand Down Expand Up @@ -641,6 +648,7 @@ where
idx.iter().map(|i| *i as usize),
|a, b| a + b,
T::Native::zero(),
idx.len() as IdxSize,
)
}
.map(|(sum, null_count)| {
Expand Down Expand Up @@ -907,6 +915,7 @@ where
idx.iter().map(|i| *i as usize),
|a, b| a + b,
0.0,
idx.len() as IdxSize,
)
}
.map(|(sum, null_count)| {
Expand Down
27 changes: 27 additions & 0 deletions py-polars/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,30 @@ def test_median_on_shifted_col_3522() -> None:
)
diffs = df.select(pl.col("foo").diff().dt.seconds())
assert diffs.select(pl.col("foo").median()).to_series()[0] == 36828.5


def test_groupby_agg_equals_zero_3535() -> None:
# setup test frame
df = pl.DataFrame(
data=[
# note: the 'bb'-keyed values should clearly sum to 0
("aa", 10, None),
("bb", -10, 0.5),
("bb", 10, -0.5),
("cc", -99, 10.5),
("cc", None, 0.0),
],
columns=[
("key", pl.Utf8),
("val1", pl.Int16),
("val2", pl.Float32),
],
)
# group by the key, aggregating the two numeric cols
assert df.groupby(pl.col("key"), maintain_order=True).agg(
[pl.col("val1").sum(), pl.col("val2").sum()]
).to_dict(False) == {
"key": ["aa", "bb", "cc"],
"val1": [10, 0, -99],
"val2": [None, 0.0, 10.5],
}

0 comments on commit 77ae80a

Please sign in to comment.