Skip to content

Commit

Permalink
fix aggregation with nans
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 11, 2020
1 parent fb1d22a commit 4c4ec9b
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 11 deletions.
2 changes: 1 addition & 1 deletion polars/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars"
version = "0.5.0"
version = "0.5.1"
authors = ["ritchie46 <ritchie46@gmail.com>"]
edition = "2018"
license = "MIT"
Expand Down
82 changes: 72 additions & 10 deletions polars/src/chunked_array/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,46 @@ use crate::datatypes::BooleanChunked;
use crate::{datatypes::PolarsNumericType, prelude::*};
use arrow::compute;
use num::{Num, NumCast, ToPrimitive};
use std::cmp::PartialOrd;
use std::cmp::{Ordering, PartialOrd};
use std::ops::{Add, Div};

/// Compares two values as floats of type `$precision`, returning an `Ordering`.
///
/// Both operands are first converted with `NumCast::from(..).unwrap()`.
/// NaN is ordered above every other value and equal to another NaN, so the
/// comparison is total and never panics on NaN input.
macro_rules! cmp_float_with_nans {
    ($a:expr, $b:expr, $precision:ty) => {{
        let lhs: $precision = NumCast::from($a).unwrap();
        let rhs: $precision = NumCast::from($b).unwrap();
        // `partial_cmp` is `None` exactly when at least one side is NaN; in that
        // case rank by the NaN flags (`false < true`), which puts NaN last.
        lhs.partial_cmp(&rhs)
            .unwrap_or_else(|| lhs.is_nan().cmp(&rhs.is_nan()))
    }};
}

/// Reduces the non-null values of `$self` with the iterator method
/// `$agg_method` (e.g. `min_by` / `max_by`), comparing elements through
/// `cmp_float_with_nans!` at float type `$precision`.
///
/// Evaluates to an `Option`: `None` when there is no non-null value.
macro_rules! agg_float_with_nans {
    ($self:ident, $agg_method:ident, $precision:ty) => {{
        if $self.null_count() == 0 {
            // No nulls: iterate the raw values without the `Option` wrapper.
            $self
                .into_no_null_iter()
                .$agg_method(|&lhs, &rhs| cmp_float_with_nans!(lhs, rhs, $precision))
        } else {
            // `flatten` drops the `None` entries and unwraps the rest.
            $self
                .into_iter()
                .flatten()
                .$agg_method(|&lhs, &rhs| cmp_float_with_nans!(lhs, rhs, $precision))
        }
    }};
}

impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: Add<Output = T::Native> + PartialOrd + Div<Output = T::Native> + Num + NumCast,
T::Native: Add<Output = T::Native>
+ PartialOrd
+ Div<Output = T::Native>
+ Num
+ NumCast
+ ToPrimitive,
{
fn sum(&self) -> Option<T::Native> {
self.downcast_chunks()
Expand All @@ -26,17 +59,35 @@ where
}

/// Returns the smallest non-null value, or `None` if there is none.
///
/// Float arrays are reduced through `agg_float_with_nans!`, whose comparison
/// ranks NaN above every number; NaN is therefore only returned when every
/// non-null value is NaN. All other types take the per-chunk result of
/// `compute::min` and keep the smallest one.
fn min(&self) -> Option<T::Native> {
    match T::get_data_type() {
        ArrowDataType::Float32 => agg_float_with_nans!(self, min_by, f32),
        ArrowDataType::Float64 => agg_float_with_nans!(self, min_by, f64),
        _ => self
            .downcast_chunks()
            .iter()
            .filter_map(|&a| compute::min(a))
            // Keep the smaller of the two per-chunk minima (`<`, not `>`).
            .fold_first(|acc, v| if acc < v { acc } else { v }),
    }
}

/// Returns the largest non-null value, or `None` if there is none.
///
/// Float arrays are reduced through `agg_float_with_nans!`, whose comparison
/// ranks NaN above every number, so a NaN among the non-null values is
/// returned as the maximum. All other types take the per-chunk result of
/// `compute::max` and keep the largest one.
fn max(&self) -> Option<T::Native> {
    match T::get_data_type() {
        ArrowDataType::Float32 => agg_float_with_nans!(self, max_by, f32),
        ArrowDataType::Float64 => agg_float_with_nans!(self, max_by, f64),
        _ => self
            .downcast_chunks()
            .iter()
            .filter_map(|&a| compute::max(a))
            // Keep the larger of the two per-chunk maxima.
            .fold_first(|acc, v| if acc > v { acc } else { v }),
    }
}

fn mean(&self) -> Option<T::Native> {
Expand Down Expand Up @@ -129,6 +180,17 @@ impl ChunkAgg<u8> for BooleanChunked {
mod test {
use crate::prelude::*;

/// `min` on float arrays must ignore NaN regardless of where it appears.
#[test]
fn test_agg_float() {
    // Assert the concrete value: comparing the two results with each other
    // would also pass if both were `None`.
    let ca1 = Float32Chunked::new_from_slice("a", &[1.0, f32::NAN]);
    let ca2 = Float32Chunked::new_from_slice("b", &[f32::NAN, 1.0]);
    assert_eq!(ca1.min(), Some(1.0));
    assert_eq!(ca2.min(), Some(1.0));

    let ca1 = Float64Chunked::new_from_slice("a", &[1.0, f64::NAN]);
    let ca2 = Float64Chunked::new_from_slice("b", &[f64::NAN, 1.0]);
    assert_eq!(ca1.min(), Some(1.0));
    assert_eq!(ca2.min(), Some(1.0));
}

#[test]
fn test_median() {
let ca = UInt32Chunked::new_from_opt_slice(
Expand Down

0 comments on commit 4c4ec9b

Please sign in to comment.