Skip to content

Commit

Permalink
fix: Don't count nulls in streaming count agg (#15051)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 14, 2024
1 parent 33c9c84 commit fd9eba2
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
Expand Up @@ -131,7 +131,7 @@ where
AExpr::Len => (
IDX_DTYPE,
Arc::new(Len {}),
AggregateFunction::Count(CountAgg::new()),
AggregateFunction::Len(CountAgg::new()),
),
AExpr::Agg(agg) => match agg {
AAggExpr::Min { input, .. } => {
Expand Down
Expand Up @@ -7,26 +7,28 @@ use polars_utils::unwrap::UnwrapUncheckedRelease;
use super::*;
use crate::operators::IdxSize;

pub(crate) struct CountAgg {
pub(crate) struct CountAgg<const INCLUDE_NULL: bool> {
count: IdxSize,
}

impl CountAgg {
impl<const INCLUDE_NULL: bool> CountAgg<INCLUDE_NULL> {
pub(crate) fn new() -> Self {
CountAgg { count: 0 }
}
fn incr(&mut self) {
self.count += 1;
}
}

impl AggregateFn for CountAgg {
impl<const INCLUDE_NULL: bool> AggregateFn for CountAgg<INCLUDE_NULL> {
fn has_physical_agg(&self) -> bool {
false
}

fn pre_agg(&mut self, _chunk_idx: IdxSize, _item: &mut dyn ExactSizeIterator<Item = AnyValue>) {
self.incr();
fn pre_agg(&mut self, _chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator<Item = AnyValue>) {
let item = unsafe { item.next().unwrap_unchecked_release() };
if INCLUDE_NULL {
self.count += 1;
} else {
self.count += !matches!(item, AnyValue::Null) as IdxSize;
}
}
fn pre_agg_ordered(
&mut self,
Expand Down
Expand Up @@ -46,7 +46,8 @@ pub(crate) trait AggregateFn: Send + Sync {
pub(crate) enum AggregateFunction {
First(FirstAgg),
Last(LastAgg),
Count(CountAgg),
Count(CountAgg<false>),
Len(CountAgg<true>),
SumF32(SumAgg<f32>),
SumF64(SumAgg<f64>),
SumU32(SumAgg<u32>),
Expand Down Expand Up @@ -83,6 +84,7 @@ impl AggregateFunction {
MeanF32(_) => MeanF32(MeanAgg::new()),
MeanF64(_) => MeanF64(MeanAgg::new()),
Count(_) => Count(CountAgg::new()),
Len(_) => Len(CountAgg::new()),
Null(a) => Null(a.clone()),
MinMaxF32(inner) => MinMaxF32(inner.split()),
MinMaxF64(inner) => MinMaxF64(inner.split()),
Expand Down
Expand Up @@ -131,11 +131,11 @@ impl<const FIXED: bool> AggHashTable<FIXED> {
pub(super) unsafe fn insert(
&mut self,
hash: u64,
row: &[u8],
key: &[u8],
agg_iters: &mut [SeriesPhysIter],
chunk_index: IdxSize,
) -> bool {
let agg_idx = match self.insert_key(hash, row) {
let agg_idx = match self.insert_key(hash, key) {
// overflow
None => return true,
Some(agg_idx) => agg_idx,
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/streaming/test_streaming_group_by.py
Expand Up @@ -446,3 +446,10 @@ def test_group_by_multiple_keys_one_literal(streaming: bool) -> None:
.to_dict(as_series=False)
== expected
)


def test_streaming_group_null_count() -> None:
df = pl.DataFrame({"g": [1] * 6, "a": ["yes", None] * 3}).lazy()
assert df.group_by("g").agg(pl.col("a").count()).collect(streaming=True).to_dict(
as_series=False
) == {"g": [1], "a": [3]}

0 comments on commit fd9eba2

Please sign in to comment.