Skip to content

Commit

Permalink
fix mean overflow (#2527)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 3, 2022
1 parent 06ed39b commit 43865fc
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 23 deletions.
47 changes: 28 additions & 19 deletions polars/polars-core/src/chunked_array/ops/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,37 @@ where
self.sum().map(|v| v.to_f64().unwrap() / len)
}
_ => {
let mut acc = None;
let len = (self.len() - self.null_count()) as f64;

let mut update_acc = |val: f64| match &acc {
None => acc = Some(val),
Some(_acc) => acc = Some(_acc + val),
};

for arr in self.downcast_iter() {
if arr.null_count() > 0 {
for v in arr.into_iter().flatten() {
let val = v.to_f64().unwrap();
update_acc(val)
}
} else {
for v in arr.values().as_slice() {
let val = v.to_f64().unwrap();
update_acc(val)
let null_count = self.null_count();
let len = self.len();
if null_count == len {
None
} else {
let mut acc = 0.0;
let len = (len - null_count) as f64;

for arr in self.downcast_iter() {
if arr.null_count() > 0 {
for v in arr.into_iter().flatten() {
// SAFETY: all of these types can be coerced to f64, so `to_f64()`
// always yields a value and `unwrap_unchecked` is sound here.
unsafe {
let val = v.to_f64().unwrap_unchecked();
acc += val
}
}
} else {
for v in arr.values().as_slice() {
// SAFETY: all of these types can be coerced to f64, so `to_f64()`
// always yields a value and `unwrap_unchecked` is sound here.
unsafe {
let val = v.to_f64().unwrap_unchecked();
acc += val
}
}
}
}
Some(acc / len)
}
acc.map(|acc| acc / len)
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions polars/polars-lazy/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -740,9 +740,7 @@ impl DefaultPlanner {
Context::Default => {
let function = NoEq::new(Arc::new(move |s: &mut [Series]| {
let s = std::mem::take(&mut s[0]);
let len = s.len() as f64;
parallel_op_series(|s| Ok(s.sum_as_series()), s, None)
.map(|s| s.cast(&DataType::Float64).unwrap() / len)
Ok(s.mean_as_series())
})
as Arc<dyn SeriesUdf>);
Ok(Arc::new(ApplyExpr {
Expand Down
1 change: 0 additions & 1 deletion py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2854,7 +2854,6 @@ def upsample(
... time_column="time", every="1mo", by="groups", maintain_order=True
... ).select(pl.all().forward_fill())
... )
shape: (7, 3)
┌─────────────────────┬────────┬────────┐
│ time ┆ groups ┆ values │
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/db-benchmark/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,24 @@

import time

import numpy as np

import polars as pl

# https://github.com/pola-rs/polars/issues/1942
# Timing check: sorting a Series of 2 << 12 (= 8192) NaN values must complete
# in under one second of wall-clock time.
t0 = time.time()
pl.repeat(float("nan"), 2 << 12).sort()
assert (time.time() - t0) < 1

# test mean overflow issues
# Fixed seed so the random data below, and therefore the expected mean, is
# reproducible.
np.random.seed(1)
# Expected mean of the seeded data; compared with np.isclose rather than ==.
mean = 769.5607652
# 5,000,000 random integers drawn from [500, 1040) in a single "value" column.
# Their sum is roughly 3.8e9, which is above the Int32 maximum (~2.1e9), so a
# mean derived from a 32-bit integer sum would overflow.
df = pl.DataFrame(np.random.randint(500, 1040, 5000000), columns=["value"])
# Mean via pl.mean("value") inside with_column, on the column as created from
# the numpy array.
assert np.isclose(df.with_column(pl.mean("value"))[0, 0], mean)
# Same check after casting the column to Int32.
assert np.isclose(
df.with_column(pl.col("value").cast(pl.Int32)).with_column(pl.mean("value"))[0, 0],
mean,
)
# Mean computed directly on the Int32 Series obtained with get_column.
assert np.isclose(
df.with_column(pl.col("value").cast(pl.Int32)).get_column("value").mean(), mean
)

0 comments on commit 43865fc

Please sign in to comment.