Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions src/distributions/weighted_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,15 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();

let zero = <X as Default>::default();
if total_weight < zero {
if !(total_weight >= zero) {
return Err(WeightedError::InvalidWeight);
}

let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
for w in iter {
if *w.borrow() < zero {
// Note that `!(w >= x)` is not equivalent to `w < x` for partially
// ordered types due to NaNs which are equal to nothing.
if !(w.borrow() >= &zero) {
return Err(WeightedError::InvalidWeight);
}
weights.push(total_weight.clone());
Expand Down Expand Up @@ -158,7 +160,7 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
return Err(WeightedError::InvalidWeight);
}
}
if *w < zero {
if !(*w >= zero) {
return Err(WeightedError::InvalidWeight);
}
if i >= self.cumulative_weights.len() + 1 {
Expand Down Expand Up @@ -256,6 +258,30 @@ mod test {
assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
}

#[test]
fn test_accepting_nan(){
assert_eq!(
WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(),
WeightedError::InvalidWeight,
);
assert_eq!(
WeightedIndex::new(&[core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight,
);
assert_eq!(
WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight,
);

assert_eq!(
WeightedIndex::new(&[0.5, 7.0])
.unwrap()
.update_weights(&[(0, &core::f32::NAN)])
.unwrap_err(),
WeightedError::InvalidWeight,
)
}


#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
Expand Down Expand Up @@ -399,8 +425,8 @@ pub enum WeightedError {
/// The provided weight collection contains no items.
NoItem,

/// A weight is either less than zero, greater than the supported maximum or
/// otherwise invalid.
/// A weight is either less than zero, greater than the supported maximum,
/// NaN, or otherwise invalid.
InvalidWeight,

/// All items in the provided weight collection are zero.
Expand Down