Skip to content

Commit

Permalink
fix(rust, python): fix overflow in partitioned groupby mean of int32/… (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 14, 2022
1 parent 5e60fc5 commit b995b36
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
10 changes: 9 additions & 1 deletion polars/polars-lazy/src/physical_plan/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,15 @@ impl PartitionedAggregation for AggregationExpr {
#[cfg(feature = "dtype-struct")]
GroupByMethod::Mean => {
let new_name = series.name().to_string();
let mut agg_s = series.agg_sum(groups);

// ensure we don't overflow
// the all 8 and 16 bits integers are already upcasted to int16 on `agg_sum`
let mut agg_s = if matches!(series.dtype(), DataType::Int32 | DataType::UInt32)
{
series.cast(&DataType::Int64).unwrap().agg_sum(groups)
} else {
series.agg_sum(groups)
};
agg_s.rename(&new_name);

if !agg_s.dtype().is_numeric() {
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/slow/test_overflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import polars as pl


def test_overflow_mean_partitioned_groupby_5194() -> None:
for dtype in [pl.Int32, pl.UInt32]:
df = pl.DataFrame(
[
pl.Series("data", [10_00_00_00] * 100_000, dtype=dtype),
pl.Series("group", [1, 2] * 50_000, dtype=dtype),
]
)
assert df.groupby("group").agg(pl.col("data").mean()).sort(by="group").to_dict(
False
) == {"group": [1, 2], "data": [10000000.0, 10000000.0]}

0 comments on commit b995b36

Please sign in to comment.