Skip to content

Commit

Permalink
fix overflow in agg_mean (#3183)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 19, 2022
1 parent 3ba7106 commit 9a4a4be
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 30 deletions.
27 changes: 23 additions & 4 deletions polars/polars-core/src/frame/groupby/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl Series {
}
}

#[cfg(feature = "private")]
#[doc(hidden)]
pub fn agg_valid_count(&self, groups: &GroupsProxy) -> Option<Series> {
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
Expand Down Expand Up @@ -110,7 +110,7 @@ impl Series {
}
}

#[cfg(feature = "private")]
#[doc(hidden)]
pub fn agg_first(&self, groups: &GroupsProxy) -> Series {
let out = match groups {
GroupsProxy::Idx(groups) => {
Expand Down Expand Up @@ -144,7 +144,7 @@ impl Series {
self.restore_logical(out)
}

#[cfg(feature = "private")]
#[doc(hidden)]
pub fn agg_n_unique(&self, groups: &GroupsProxy) -> Option<Series> {
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
Expand All @@ -169,7 +169,26 @@ impl Series {
}
}

#[cfg(feature = "private")]
#[doc(hidden)]
pub fn agg_mean(&self, groups: &GroupsProxy) -> Option<Series> {
use DataType::*;
match self.dtype() {
// risk of overflow
UInt8 | UInt16 | Int8 | Int16 => {
self.cast(&DataType::Float64).unwrap().agg_mean(groups)
}
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_mean(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups),
Int32 => self.i32().unwrap().agg_mean(groups),
Int64 => self.i64().unwrap().agg_mean(groups),
UInt32 => self.u32().unwrap().agg_mean(groups),
UInt64 => self.u64().unwrap().agg_mean(groups),
// logical types don't have agg_mean
_ => None,
}
}

#[doc(hidden)]
pub fn agg_last(&self, groups: &GroupsProxy) -> Series {
let out = match groups {
GroupsProxy::Idx(groups) => {
Expand Down
5 changes: 0 additions & 5 deletions polars/polars-core/src/series/implementations/dates_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,6 @@ macro_rules! impl_dyn_series {
self.0.vec_hash_combine(build_hasher, hashes)
}

fn agg_mean(&self, _groups: &GroupsProxy) -> Option<Series> {
// does not make sense on logical
None
}

fn agg_min(&self, groups: &GroupsProxy) -> Option<Series> {
self.0
.agg_min(groups)
Expand Down
5 changes: 0 additions & 5 deletions polars/polars-core/src/series/implementations/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ impl private::PrivateSeries for SeriesWrap<DatetimeChunked> {
self.0.vec_hash_combine(build_hasher, hashes)
}

fn agg_mean(&self, _groups: &GroupsProxy) -> Option<Series> {
// does not make sense on logical
None
}

fn agg_min(&self, groups: &GroupsProxy) -> Option<Series> {
self.0.agg_min(groups).map(|ca| {
ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone())
Expand Down
5 changes: 0 additions & 5 deletions polars/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,6 @@ impl private::PrivateSeries for SeriesWrap<DurationChunked> {
self.0.vec_hash_combine(build_hasher, hashes)
}

fn agg_mean(&self, _groups: &GroupsProxy) -> Option<Series> {
// does not make sense on logical
None
}

fn agg_min(&self, groups: &GroupsProxy) -> Option<Series> {
self.0
.agg_min(groups)
Expand Down
4 changes: 0 additions & 4 deletions polars/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,6 @@ macro_rules! impl_dyn_series {
self.0.vec_hash_combine(build_hasher, hashes)
}

fn agg_mean(&self, groups: &GroupsProxy) -> Option<Series> {
self.agg_mean(groups)
}

fn agg_min(&self, groups: &GroupsProxy) -> Option<Series> {
self.0.agg_min(groups)
}
Expand Down
4 changes: 0 additions & 4 deletions polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,6 @@ macro_rules! impl_dyn_series {
self.0.vec_hash_combine(build_hasher, hashes)
}

fn agg_mean(&self, groups: &GroupsProxy) -> Option<Series> {
self.0.agg_mean(groups)
}

fn agg_min(&self, groups: &GroupsProxy) -> Option<Series> {
self.0.agg_min(groups)
}
Expand Down
3 changes: 0 additions & 3 deletions polars/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,6 @@ pub(crate) mod private {
fn vec_hash_combine(&self, _build_hasher: RandomState, _hashes: &mut [u64]) {
invalid_operation_panic!(self)
}
fn agg_mean(&self, _groups: &GroupsProxy) -> Option<Series> {
None
}
fn agg_min(&self, _groups: &GroupsProxy) -> Option<Series> {
None
}
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,23 @@ def test_agg_after_head() -> None:
out = out.sort("a")

assert out.frame_equal(expected)


def test_overflow_uint16_agg_mean() -> None:
assert (
pl.DataFrame(
{
"col1": ["A" for _ in range(1025)],
"col3": [64 for i in range(1025)],
}
)
.with_columns(
[
pl.col("col3").cast(pl.UInt16),
]
)
.groupby(["col1"])
.agg(pl.col("col3").mean())
.to_dict(False)
== {"col1": ["A"], "col3": [64.0]}
)

0 comments on commit 9a4a4be

Please sign in to comment.