Skip to content
Merged
23 changes: 23 additions & 0 deletions src/maybe_nan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,15 @@ where
A: 'a,
F: FnMut(B, &'a A::NotNan) -> B;

/// Traverse the non-NaN elements and their indices and apply a fold,
/// returning the resulting value.
///
/// Elements are visited in arbitrary order.
fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
Copy link
Contributor Author

@phungleson phungleson Mar 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, I wonder if it is better if we introduce indexed_iter_skipnan instead of indexed_fold_skipnan so that it can be used for different purposes, i.e. indexed_iter_skipnan().fold() is one of the usecases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't mind adding indexed_iter_skipnan() in addition to indexed_fold_skipnan(), but the semantics of indexed_iter_skipnan().fold() and indexed_fold_skipnan() are a bit different. (I'd assume that indexed_iter_skipnan() would iterate in standard order to be consistent with indexed_iter(), while I wouldn't assume that for indexed_fold_skipnan().)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, I forgot about the order semantic, in that case, perhaps indexed_iter_skipnan can be added later when it has actual usecase, wdyt? is there anything you want me to look into before merging this?

where
A: 'a,
F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B;

/// Visit each non-NaN element in the array by calling `f` on each element.
///
/// Elements are visited in arbitrary order.
Expand Down Expand Up @@ -302,6 +311,20 @@ where
})
}

fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B
where
A: 'a,
F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B,
{
self.indexed_iter().fold(init, |acc, (idx, elem)| {
if let Some(not_nan) = elem.try_as_not_nan() {
f(acc, (idx, not_nan))
} else {
acc
}
})
}

fn visit_skipnan<'a, F>(&'a self, mut f: F)
where
A: 'a,
Expand Down
98 changes: 98 additions & 0 deletions src/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,33 @@ where
where
A: PartialOrd;

/// Finds the index of the minimum value of the array skipping NaN values.
///
/// Returns `None` if the array is empty or none of the values in the array
/// are non-NaN values.
///
/// Even if there are multiple (equal) elements that are minima, only one
/// index is returned. (Which one is returned is unspecified and may depend
/// on the memory layout of the array.)
///
/// # Example
///
/// ```
/// extern crate ndarray;
/// extern crate ndarray_stats;
///
/// use ndarray::array;
/// use ndarray_stats::QuantileExt;
///
/// let a = array![[::std::f64::NAN, 3., 5.],
/// [2., 0., 6.]];
/// assert_eq!(a.argmin_skipnan(), Some((1, 1)));
/// ```
fn argmin_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord;

/// Finds the elementwise minimum of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
Expand Down Expand Up @@ -269,6 +296,33 @@ where
where
A: PartialOrd;

/// Finds the index of the maximum value of the array skipping NaN values.
///
/// Returns `None` if the array is empty or none of the values in the array
/// are non-NaN values.
///
/// Even if there are multiple (equal) elements that are maxima, only one
/// index is returned. (Which one is returned is unspecified and may depend
/// on the memory layout of the array.)
///
/// # Example
///
/// ```
/// extern crate ndarray;
/// extern crate ndarray_stats;
///
/// use ndarray::array;
/// use ndarray_stats::QuantileExt;
///
/// let a = array![[::std::f64::NAN, 3., 5.],
/// [2., 0., 6.]];
/// assert_eq!(a.argmax_skipnan(), Some((1, 2)));
/// ```
fn argmax_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord;

/// Finds the elementwise maximum of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
Expand Down Expand Up @@ -369,6 +423,28 @@ where
Some(current_pattern_min)
}

fn argmin_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord,
{
let mut pattern_min = D::zeros(self.ndim()).into_pattern();
let min = self.indexed_fold_skipnan(None, |current_min, (pattern, elem)| {
Some(match current_min {
Some(m) if (m <= elem) => m,
_ => {
pattern_min = pattern;
elem
}
})
});
if min.is_some() {
Some(pattern_min)
} else {
None
}
}

fn min(&self) -> Option<&A>
where
A: PartialOrd,
Expand Down Expand Up @@ -411,6 +487,28 @@ where
Some(current_pattern_max)
}

fn argmax_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord,
{
let mut pattern_max = D::zeros(self.ndim()).into_pattern();
let max = self.indexed_fold_skipnan(None, |current_max, (pattern, elem)| {
Some(match current_max {
Some(m) if m >= elem => m,
_ => {
pattern_max = pattern;
elem
}
})
});
if max.is_some() {
Some(pattern_max)
} else {
None
}
}

fn max(&self) -> Option<&A>
where
A: PartialOrd,
Expand Down
65 changes: 65 additions & 0 deletions tests/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,37 @@ quickcheck! {
}
}

#[test]
fn test_argmin_skipnan() {
let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.argmin_skipnan(), Some((1, 1)));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.argmin_skipnan(), Some((0, 0)));

let a = array![[::std::f64::NAN, 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.argmin_skipnan(), Some((1, 0)));

let a: Array2<f64> = array![[], []];
assert_eq!(a.argmin_skipnan(), None);

let a = arr2(&[[::std::f64::NAN; 2]; 2]);
assert_eq!(a.argmin_skipnan(), None);
}

quickcheck! {
fn argmin_skipnan_matches_min_skipnan(data: Vec<Option<i32>>) -> bool {
let a = Array1::from(data);
let min = a.min_skipnan();
let argmin = a.argmin_skipnan();
if min.is_none() {
argmin == None
} else {
a[argmin.unwrap()] == *min
}
}
}

#[test]
fn test_min() {
let a = array![[1, 5, 3], [2, 0, 6]];
Expand Down Expand Up @@ -81,6 +112,40 @@ quickcheck! {
}
}

#[test]
fn test_argmax_skipnan() {
let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.argmax_skipnan(), Some((1, 2)));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, ::std::f64::NAN]];
assert_eq!(a.argmax_skipnan(), Some((0, 1)));

let a = array![
[::std::f64::NAN, ::std::f64::NAN, 3.],
[2., ::std::f64::NAN, 6.]
];
assert_eq!(a.argmax_skipnan(), Some((1, 2)));

let a: Array2<f64> = array![[], []];
assert_eq!(a.argmax_skipnan(), None);

let a = arr2(&[[::std::f64::NAN; 2]; 2]);
assert_eq!(a.argmax_skipnan(), None);
}

quickcheck! {
fn argmax_skipnan_matches_max_skipnan(data: Vec<Option<i32>>) -> bool {
let a = Array1::from(data);
let max = a.max_skipnan();
let argmax = a.argmax_skipnan();
if max.is_none() {
argmax == None
} else {
a[argmax.unwrap()] == *max
}
}
}

#[test]
fn test_max() {
let a = array![[1, 5, 7], [2, 0, 6]];
Expand Down