Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add methods weight, weights, and total_weight to weighted_index.rs #1420

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md).
You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful.

## [Unreleased]
- Add `rand::distributions::WeightedIndex::{weight, weights, total_weight}` (#1420)
- Bump the MSRV to 1.61.0

## [0.9.0-alpha.1] - 2024-03-18
Expand Down
188 changes: 188 additions & 0 deletions src/distributions/weighted_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use core::fmt;

// Note that this whole module is only imported if feature="alloc" is enabled.
use alloc::vec::Vec;
use core::fmt::Debug;

#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -243,6 +244,124 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
}
}

/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution.
/// This is returned by [`WeightedIndex::weights`].
pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
weighted_index: &'a WeightedIndex<X>,
index: usize,
}

impl<'a, X> Debug for WeightedIndexIter<'a, X>
where
X: SampleUniform + PartialOrd + Debug,
X::Sampler: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WeightedIndexIter")
.field("weighted_index", &self.weighted_index)
.field("index", &self.index)
.finish()
}
}

impl<'a, X> Clone for WeightedIndexIter<'a, X>
where
X: SampleUniform + PartialOrd,
{
fn clone(&self) -> Self {
WeightedIndexIter {
weighted_index: self.weighted_index,
index: self.index,
}
}
}

impl<'a, X> Iterator for WeightedIndexIter<'a, X>
where
X: for<'b> ::core::ops::SubAssign<&'b X>
+ SampleUniform
+ PartialOrd
+ Clone,
{
type Item = X;

fn next(&mut self) -> Option<Self::Item> {
match self.weighted_index.weight(self.index) {
None => None,
Some(weight) => {
self.index += 1;
Some(weight)
}
}
}
}

impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
/// Returns the weight at the given index, if it exists.
///
/// If the index is out of bounds, this will return `None`.
///
/// # Example
///
/// ```
/// use rand::distributions::WeightedIndex;
///
/// let weights = [0, 1, 2];
/// let dist = WeightedIndex::new(&weights).unwrap();
/// assert_eq!(dist.weight(0), Some(0));
/// assert_eq!(dist.weight(1), Some(1));
/// assert_eq!(dist.weight(2), Some(2));
/// assert_eq!(dist.weight(3), None);
/// ```
pub fn weight(&self, index: usize) -> Option<X>
where
X: for<'a> ::core::ops::SubAssign<&'a X>
{
let mut weight = if index < self.cumulative_weights.len() {
self.cumulative_weights[index].clone()
} else if index == self.cumulative_weights.len() {
self.total_weight.clone()
} else {
return None;
};
if index > 0 {
weight -= &self.cumulative_weights[index - 1];
}
Some(weight)
}

/// Returns a lazy-loading iterator containing the current weights of this distribution.
///
/// If this distribution has not been updated since its creation, this will return the
/// same weights as were passed to `new`.
///
/// # Example
///
/// ```
/// use rand::distributions::WeightedIndex;
///
/// let weights = [1, 2, 3];
/// let mut dist = WeightedIndex::new(&weights).unwrap();
/// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![1, 2, 3]);
/// dist.update_weights(&[(0, &2)]).unwrap();
/// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![2, 2, 3]);
/// ```
pub fn weights(&self) -> WeightedIndexIter<'_, X>
where
X: for<'a> ::core::ops::SubAssign<&'a X>
{
WeightedIndexIter {
weighted_index: self,
index: 0,
}
}

/// Returns the sum of all weights in this distribution.
pub fn total_weight(&self) -> X {
self.total_weight.clone()
}
}

impl<X> Distribution<usize> for WeightedIndex<X>
where
X: SampleUniform + PartialOrd,
Expand Down Expand Up @@ -458,6 +577,75 @@ mod test {
}
}

#[test]
fn test_update_weights_errors() {
let data = [
(
&[1i32, 0, 0][..],
&[(0, &0)][..],
WeightError::InsufficientNonZero,
),
(
&[10, 10, 10, 10][..],
&[(1, &-11)][..],
WeightError::InvalidWeight, // A weight is negative
),
(
&[1, 2, 3, 4, 5][..],
&[(1, &5), (0, &5)][..], // Wrong order
WeightError::InvalidInput,
),
(
&[1][..],
&[(1, &1)][..], // Index too large
WeightError::InvalidInput,
),
];

for (weights, update, err) in data.iter() {
let total_weight = weights.iter().sum::<i32>();
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
assert_eq!(distr.total_weight, total_weight);
match distr.update_weights(update) {
Ok(_) => panic!("Expected update_weights to fail, but it succeeded"),
Err(e) => assert_eq!(e, *err),
}
}
}

#[test]
fn test_weight_at() {
let data = [
&[1][..],
&[10, 2, 3, 4][..],
&[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
&[u32::MAX][..],
];

for weights in data.iter() {
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
for (i, weight) in weights.iter().enumerate() {
assert_eq!(distr.weight(i), Some(*weight));
}
assert_eq!(distr.weight(weights.len()), None);
}
}

#[test]
fn test_weights() {
let data = [
&[1][..],
&[10, 2, 3, 4][..],
&[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
&[u32::MAX][..],
];

for weights in data.iter() {
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
assert_eq!(distr.weights().collect::<Vec<_>>(), weights.to_vec());
}
}

#[test]
fn value_stability() {
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
Expand Down