Skip to content

Commit

Permalink
fix median float dispatch and fix overflowing mean aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 30, 2022
1 parent 3587b6a commit 494eaa3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
40 changes: 38 additions & 2 deletions polars/polars-core/src/chunked_array/ops/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,36 @@ where
}

fn mean(&self) -> Option<f64> {
let len = (self.len() - self.null_count()) as f64;
self.sum().map(|v| v.to_f64().unwrap() / len)
match self.dtype() {
DataType::Float64 => {
let len = (self.len() - self.null_count()) as f64;
self.sum().map(|v| v.to_f64().unwrap() / len)
}
_ => {
let mut acc = None;
let len = (self.len() - self.null_count()) as f64;

let mut update_acc = |val: f64| match &acc {
None => acc = Some(val),
Some(_acc) => acc = Some(_acc + val),
};

for arr in self.downcast_iter() {
if arr.null_count() > 0 {
for v in arr.into_iter().flatten() {
let val = v.to_f64().unwrap();
update_acc(val)
}
} else {
for v in arr.values().as_slice() {
let val = v.to_f64().unwrap();
update_acc(val)
}
}
}
acc.map(|acc| acc / len)
}
}
}
}

Expand Down Expand Up @@ -1348,4 +1376,12 @@ mod test {
< 0.0001
);
}

#[test]
fn test_median_floats() {
let a = Series::new("a", &[1.0f64, 2.0, 3.0]);
let expected = Series::new("a", [2.0f64]);
assert!(a.median_as_series().series_equal_missing(&expected));
assert_eq!(a.median(), Some(2.0f64))
}
}
4 changes: 4 additions & 0 deletions polars/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ macro_rules! impl_dyn_series {
self.0.mean()
}

fn median(&self) -> Option<f64> {
self.0.median().map(|v| v as f64)
}

fn take(&self, indices: &UInt32Chunked) -> Result<Series> {
let indices = if indices.chunks.len() > 1 {
Cow::Owned(indices.rechunk())
Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,3 +1453,8 @@ def test_duration_extract_times() -> None:

expected = pl.Series("b", [3600 * 24 * int(1e9)])
verify_series_and_expr_api(duration, expected, "dt.nanoseconds")


def test_mean_overflow() -> None:
arr = np.array([255] * (1 << 17), dtype="int16")
assert arr.mean() == 255.0

0 comments on commit 494eaa3

Please sign in to comment.