Skip to content

Commit

Permalink
Make it safe against overflows
Browse files Browse the repository at this point in the history
  • Loading branch information
xmakro committed Jan 12, 2024
1 parent 75c150f commit 30866d6
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 21 deletions.
3 changes: 2 additions & 1 deletion rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@
//! - [`UnitBall`] distribution
//! - [`UnitCircle`] distribution
//! - [`UnitDisc`] distribution
//! - Alternative implementation for weighted index sampling
//! - Alternative implementations for weighted index sampling
//! - [`WeightedAliasIndex`] distribution
//! - [`WeightedTreeIndex`] distribution
//! - Misc. distributions
//! - [`InverseGaussian`] distribution
//! - [`NormalInverseGaussian`] distribution
Expand Down
80 changes: 61 additions & 19 deletions rand_distr/src/weighted_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
//! This module contains an implementation of a tree structure for sampling random
//! indices with probabilities proportional to a collection of weights.

use core::ops::{Add, AddAssign, Sub, SubAssign};
use core::ops::{Sub, SubAssign};

use super::WeightedError;
use crate::Distribution;
use alloc::{vec, vec::Vec};
use num_traits::Zero;
use num_traits::{Zero, CheckedAdd};
use rand::{distributions::uniform::SampleUniform, Rng};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -113,7 +113,8 @@ impl<W: Weight> WeightedTreeIndex<W> {
} else {
W::zero()
};
subtotals[i] = weights[i] + left_subtotal + right_subtotal;
let children_subtotal = left_subtotal.checked_add(&right_subtotal).ok_or(WeightedError::Overflow)?;
subtotals[i] = weights[i].checked_add(&children_subtotal).ok_or(WeightedError::Overflow)?;
}
Ok(Self { subtotals })
}
Expand Down Expand Up @@ -163,11 +164,16 @@ impl<W: Weight> WeightedTreeIndex<W> {
if weight < W::zero() {
return Err(WeightedError::InvalidWeight);
}
if let Some(total) = self.subtotals.first() {
if total.checked_add(&weight).is_none() {
return Err(WeightedError::Overflow);
}
}
let mut index = self.len();
self.subtotals.push(weight);
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index] += weight;
self.subtotals[index] = self.subtotals[index].checked_add(&weight).unwrap();
}
Ok(())
}
Expand All @@ -181,10 +187,15 @@ impl<W: Weight> WeightedTreeIndex<W> {
if difference == W::zero() {
return Ok(());
}
self.subtotals[index] += difference;
if let Some(total) = self.subtotals.first() {
if total.checked_add(&difference).is_none() {
return Err(WeightedError::Overflow);
}
}
self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap();
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index] += difference;
self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap();
}
Ok(())
}
Expand Down Expand Up @@ -246,27 +257,49 @@ pub trait Weight:
+ Copy
+ SampleUniform
+ PartialOrd
+ Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Zero
{
/// Adds two numbers, checking for overflow. If overflow happens, None is returned.
fn checked_add(&self, b: &Self) -> Option<Self>;
}

impl<T> Weight for T where
T: Sized
+ Copy
+ SampleUniform
+ PartialOrd
+ Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Zero
{
/// Implements [`Weight`] for a floating-point primitive.
///
/// The generated `checked_add` performs a plain addition and always yields
/// `Some`: it never reports an overflow for floating-point weights.
macro_rules! impl_weight_for_float {
    ($float: ident) => {
        impl Weight for $float {
            /// Adds `b` to `self`; the result is always wrapped in `Some`.
            fn checked_add(&self, b: &Self) -> Option<Self> {
                let sum = *self + *b;
                Some(sum)
            }
        }
    };
}

/// Implements [`Weight`] for an integer primitive.
///
/// The generated `checked_add` forwards to the `CheckedAdd` implementation of
/// the integer type and returns whatever that call returns.
macro_rules! impl_weight_for_int {
    ($int: ident) => {
        impl Weight for $int {
            /// Adds `b` to `self` by delegating to `CheckedAdd::checked_add`.
            fn checked_add(&self, b: &Self) -> Option<Self> {
                <$int as CheckedAdd>::checked_add(self, b)
            }
        }
    };
}

// Floating-point weights: `checked_add` is a plain addition that always returns `Some`.
impl_weight_for_float!(f64);
impl_weight_for_float!(f32);
// Integer weights (unsigned, then signed): `checked_add` delegates to `CheckedAdd::checked_add`.
impl_weight_for_int!(usize);
impl_weight_for_int!(u128);
impl_weight_for_int!(u64);
impl_weight_for_int!(u32);
impl_weight_for_int!(u16);
impl_weight_for_int!(u8);
impl_weight_for_int!(isize);
impl_weight_for_int!(i128);
impl_weight_for_int!(i64);
impl_weight_for_int!(i32);
impl_weight_for_int!(i16);
impl_weight_for_int!(i8);

#[cfg(test)]
mod test {
use super::*;
Expand All @@ -278,6 +311,15 @@ mod test {
assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem);
}

#[test]
fn test_overflow_error() {
    // Constructing a tree whose weights sum past `i32::MAX` must be rejected.
    let rejected = WeightedTreeIndex::new(&[i32::MAX, 2]);
    assert_eq!(rejected, Err(WeightedError::Overflow));

    // The total starts at `i32::MAX - 1`, so only one more unit of weight fits.
    let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap();

    // Both pushing a new weight and raising an existing one past the limit must fail.
    let pushed = tree.push(3);
    assert_eq!(pushed, Err(WeightedError::Overflow));
    let updated = tree.update(1, 4);
    assert_eq!(updated, Err(WeightedError::Overflow));

    // Raising the weight by exactly the remaining headroom still succeeds.
    tree.update(1, 2).unwrap();
}

#[test]
fn test_all_weights_zero_error() {
let tree = WeightedTreeIndex::<f64>::new(&[0.0, 0.0]).unwrap();
Expand Down
4 changes: 3 additions & 1 deletion src/distributions/weighted_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ use serde::{Serialize, Deserialize};
/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
/// `N` is the number of weights. As an alternative,
/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html)
/// supports `O(1)` sampling, but with much higher initialisation cost.
/// supports `O(1)` sampling, but with much higher initialisation cost,
/// and [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html)
/// supports `O(log N)` updates of individual weights.
///
/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
/// size is the sum of the size of those objects, possibly plus some alignment.
Expand Down

0 comments on commit 30866d6

Please sign in to comment.