Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve algorithm for sampling Beta #1000

Merged
merged 11 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
- New `Beta` sampling algorithm for improved performance and accuracy (#1000)

## [0.3.0] - 2020-08-25
- Move alias method for `WeightedIndex` from `rand` (#945)
- Rename `WeightedIndex` to `WeightedAliasIndex` (#1008)
Expand Down
16 changes: 10 additions & 6 deletions rand_distr/benches/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::mem::size_of;
use test::Bencher;

use rand::prelude::*;
use rand_distr::{weighted::WeightedIndex, *};
use rand_distr::*;

// At this time, distributions are optimised for 64-bit platforms.
use rand_pcg::Pcg64Mcg;
Expand Down Expand Up @@ -112,11 +112,15 @@ distr_float!(distr_normal, f64, Normal::new(-1.23, 4.56).unwrap());
distr_float!(distr_log_normal, f64, LogNormal::new(-1.23, 4.56).unwrap());
distr_float!(distr_gamma_large_shape, f64, Gamma::new(10., 1.0).unwrap());
distr_float!(distr_gamma_small_shape, f64, Gamma::new(0.1, 1.0).unwrap());
distr_float!(distr_beta_small_param, f64, Beta::new(0.1, 0.1).unwrap());
distr_float!(distr_beta_large_param_similar, f64, Beta::new(101., 95.).unwrap());
distr_float!(distr_beta_large_param_different, f64, Beta::new(10., 1000.).unwrap());
distr_float!(distr_beta_mixed_param, f64, Beta::new(0.5, 100.).unwrap());
distr_float!(distr_cauchy, f64, Cauchy::new(4.2, 6.9).unwrap());
distr_float!(distr_triangular, f64, Triangular::new(0., 1., 0.9).unwrap());
distr_int!(distr_binomial, u64, Binomial::new(20, 0.7).unwrap());
distr_int!(distr_binomial_small, u64, Binomial::new(1000000, 1e-30).unwrap());
distr_int!(distr_poisson, u64, Poisson::new(4.0).unwrap());
distr_float!(distr_poisson, f64, Poisson::new(4.0).unwrap());
distr!(distr_bernoulli, bool, Bernoulli::new(0.18).unwrap());
distr_arr!(distr_circle, [f64; 2], UnitCircle);
distr_arr!(distr_sphere, [f64; 3], UnitSphere);
Expand All @@ -127,10 +131,10 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0,
distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());

distr_int!(distr_weighted_alias_method_i8, usize, weighted::alias_method::WeightedIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_alias_method_u32, usize, weighted::alias_method::WeightedIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_alias_method_f64, usize, weighted::alias_method::WeightedIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
distr_int!(distr_weighted_alias_method_large_set, usize, weighted::alias_method::WeightedIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap());
distr_int!(distr_weighted_alias_method_i8, usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_alias_method_u32, usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_alias_method_f64, usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
distr_int!(distr_weighted_alias_method_large_set, usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap());


#[bench]
Expand Down
180 changes: 165 additions & 15 deletions rand_distr/src/gamma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,38 @@ where
}
}

/// The algorithm used for sampling the Beta distribution.
///
/// `Beta::new` selects the variant once, based on the smaller of the two
/// shape parameters, and stores the constants that algorithm needs;
/// `sample` then dispatches on the stored variant.
///
/// Reference:
///
/// R. C. H. Cheng (1978).
/// Generating beta variates with nonintegral shape parameters.
/// Communications of the ACM 21, 317-322.
/// <https://doi.org/10.1145/359460.359482>
#[derive(Clone, Copy, Debug)]
enum BetaAlgorithm<N> {
/// Algorithm BB, chosen when `min(alpha, beta) > 1`.
BB(BB<N>),
/// Algorithm BC, chosen when `min(alpha, beta) <= 1`.
BC(BC<N>),
}

/// Algorithm BB for `min(alpha, beta) > 1`.
///
/// Holds the constants that `Beta::new` precomputes for this algorithm.
/// The field names follow the notation of the reference (Cheng, 1978) and
/// are *not* the shape parameters of the distribution. In the formulas
/// below, `a` is the smaller and `b` the larger shape parameter.
#[derive(Clone, Copy, Debug)]
struct BB<N> {
/// `a + b`.
alpha: N,
/// `sqrt((alpha - 2) / (2*a*b - alpha))`.
beta: N,
/// `a + 1 / beta`.
gamma: N,
}

/// Algorithm BC for `min(alpha, beta) <= 1`.
///
/// Holds the constants that `Beta::new` precomputes for this algorithm.
/// The field names follow the notation of the reference (Cheng, 1978) and
/// are *not* the shape parameters of the distribution. In the formulas
/// below, `a` is the larger and `b` the smaller shape parameter (the
/// opposite ordering to algorithm BB).
#[derive(Clone, Copy, Debug)]
struct BC<N> {
/// `a + b`.
alpha: N,
/// `1 / b`.
beta: N,
/// `1 + a - b`; feeds into `kappa1` and `kappa2`.
delta: N,
/// `delta * (1/72 + (3/72)*b) / (a*beta - 14/18)`; rejection threshold
/// compared against in step 2 of the sampling loop.
kappa1: N,
/// `0.25 + (0.5 + 0.25/delta) * b`; rejection threshold compared against
/// in step 4 of the sampling loop.
kappa2: N,
}

/// The Beta distribution with shape parameters `alpha` and `beta`.
///
/// # Example
Expand All @@ -510,12 +542,10 @@ where
pub struct Beta<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
gamma_a: Gamma<F>,
gamma_b: Gamma<F>,
a: F, b: F, switched_params: bool,
algorithm: BetaAlgorithm<F>,
}

/// Error type returned from `Beta::new`.
Expand All @@ -542,31 +572,142 @@ impl std::error::Error for BetaError {}
impl<F> Beta<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Construct an object representing the `Beta(alpha, beta)`
/// distribution.
pub fn new(alpha: F, beta: F) -> Result<Beta<F>, BetaError> {
Ok(Beta {
gamma_a: Gamma::new(alpha, F::one()).map_err(|_| BetaError::AlphaTooSmall)?,
gamma_b: Gamma::new(beta, F::one()).map_err(|_| BetaError::BetaTooSmall)?,
})
if !(alpha > F::zero()) {
return Err(BetaError::AlphaTooSmall);
}
if !(beta > F::zero()) {
return Err(BetaError::BetaTooSmall);
}
// From now on, we use the notation from the reference,
// i.e. `alpha` and `beta` are renamed to `a0` and `b0`.
let (a0, b0) = (alpha, beta);
let (a, b, switched_params) = if a0 < b0 {
(a0, b0, false)
} else {
(b0, a0, true)
};
if a > F::one() {
// Algorithm BB
let alpha = a + b;
let beta = ((alpha - F::from(2.).unwrap())
/ (F::from(2.).unwrap()*a*b - alpha)).sqrt();
let gamma = a + F::one() / beta;

Ok(Beta {
a, b, switched_params,
algorithm: BetaAlgorithm::BB(BB {
alpha, beta, gamma,
})
})
} else {
// Algorithm BC
//
// Here `a` is the maximum instead of the minimum.
let (a, b, switched_params) = (b, a, !switched_params);
let alpha = a + b;
let beta = F::one() / b;
vks marked this conversation as resolved.
Show resolved Hide resolved
let delta = F::one() + a - b;
vks marked this conversation as resolved.
Show resolved Hide resolved
let kappa1 = delta
* (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b)
vks marked this conversation as resolved.
Show resolved Hide resolved
/ (a*beta - F::from(14. / 18.).unwrap());
vks marked this conversation as resolved.
Show resolved Hide resolved
let kappa2 = F::from(0.25).unwrap()
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b;
vks marked this conversation as resolved.
Show resolved Hide resolved

Ok(Beta {
a, b, switched_params,
algorithm: BetaAlgorithm::BC(BC {
alpha, beta, delta, kappa1, kappa2,
})
})
}
}
}

impl<F> Distribution<F> for Beta<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let x = self.gamma_a.sample(rng);
let y = self.gamma_b.sample(rng);
x / (x + y)
let mut w;
match self.algorithm {
BetaAlgorithm::BB(algo) => {
loop {
// 1.
let u1 = rng.sample(Open01);
let u2 = rng.sample(Open01);
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
let z = u1*u1 * u2;
let r = algo.gamma * v - F::from(4.).unwrap().ln();
let s = self.a + r - w;
// 2.
if s + F::one() + F::from(5.).unwrap().ln()
>= F::from(5.).unwrap() * z {
break;
}
// 3.
let t = z.ln();
if s >= t {
break;
}
// 4.
if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
break;
}
}
},
BetaAlgorithm::BC(algo) => {
loop {
let z;
// 1.
let u1 = rng.sample(Open01);
let u2 = rng.sample(Open01);
if u1 < F::from(0.5).unwrap() {
// 2.
let y = u1 * u2;
z = u1 * y;
if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
vks marked this conversation as resolved.
Show resolved Hide resolved
continue;
}
} else {
// 3.
z = u1 * u1 * u2;
if z <= F::from(0.25).unwrap() {
vks marked this conversation as resolved.
Show resolved Hide resolved
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
break;
}
// 4.
if z >= algo.kappa2 {
continue;
}
}
// 5.
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
- F::from(4.).unwrap().ln() < z.ln()) {
break;
};
}
},
};
// 5. for BB, 6. for BC
if !self.switched_params {
if w == F::infinity() {
// Assuming `b` is finite, for large `w`:
return F::one();
}
w / (self.b + w)
} else {
self.b / (self.b + w)
}
}
}

Expand Down Expand Up @@ -636,4 +777,13 @@ mod test {
fn test_beta_invalid_dof() {
Beta::new(0., 0.).unwrap();
}

#[test]
fn test_beta_small_param() {
    // Both shape parameters are far below 1; every draw must still be a
    // number (never NaN).
    let mut rng = crate::test::rng(206);
    let dist = Beta::<f64>::new(1e-3, 1e-3).unwrap();
    for i in 0..1000 {
        let drawn: f64 = dist.sample(&mut rng);
        assert!(!drawn.is_nan(), "failed at i={}", i);
    }
}
}
32 changes: 19 additions & 13 deletions rand_distr/tests/value_stability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ fn normal_inverse_gaussian_stability() {
fn pert_stability() {
// mean = 4, var = 12/7
test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[
4.631484136029422f64,
3.307201472321789f64,
3.29995019556348f64,
3.66835483991721f64,
3.514246139933899f64,
4.908681667460367,
4.014196196158352,
2.6489397149197234,
3.4569780580044727,
4.242864311947118,
]);
}

Expand Down Expand Up @@ -200,15 +200,21 @@ fn gamma_stability() {
-2.377641221169782,
]);

// Beta has same special cases as Gamma on each param
// Beta has two special cases:
//
// 1. min(alpha, beta) <= 1
// 2. min(alpha, beta) > 1
test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[
0.6444564f32, 0.357635, 0.4110078, 0.7347192,
]);
test_samples(223, Beta::new(0.7, 1.2).unwrap(), &[
0.6433129944095513f64,
0.5373371199711573,
0.10313293199269491,
0.002472280249144378,
0.8300703726659456,
0.8134131062097899,
0.47912589330631555,
0.25323238071138526,
]);
test_samples(223, Beta::new(3.0, 1.2).unwrap(), &[
0.49563509121756827,
0.9551305482256759,
0.5151181353461637,
0.7551732971235077,
]);
}

Expand Down
1 change: 1 addition & 0 deletions utils/ci/script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ main() {
if [ "0$NIGHTLY" -ge 1 ]; then
$CARGO test $TARGET --all-features
$CARGO test $TARGET --benches --features=nightly
$CARGO test $TARGET --manifest-path rand_distr/Cargo.toml --benches
else
# all stable features:
$CARGO test $TARGET --features=serde1,log,small_rng
Expand Down