Skip to content

Commit

Permalink
Add PertBuilder; allow specification via mean or mode
Browse files Browse the repository at this point in the history
  • Loading branch information
dhardy committed May 16, 2024
1 parent 3888d88 commit b51268e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 18 deletions.
82 changes: 65 additions & 17 deletions rand_distr/src/pert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use rand::Rng;
/// ```rust
/// use rand_distr::{Pert, Distribution};
///
/// let d = Pert::new(0., 5., 2.5).unwrap();
/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap();
/// let v = d.sample(&mut rand::thread_rng());
/// println!("{} is from a PERT distribution", v);
/// ```
Expand Down Expand Up @@ -75,27 +75,71 @@ where
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Set up the PERT distribution with defined `min`, `max` and `mode`.
/// Construct a PERT distribution with defined `min`, `max`
///
/// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`.
/// # Example
///
/// ```
/// use rand_distr::Pert;
/// let pert_dist = Pert::new(0.0, 10.0)
/// .with_shape(3.5)
/// .with_mean(3.0)
/// .unwrap();
/// # let _unused: Pert<f64> = pert_dist;
/// ```
#[inline]
pub fn new(min: F, max: F, mode: F) -> Result<Pert<F>, PertError> {
Pert::new_with_shape(min, max, mode, F::from(4.).unwrap())
pub fn new(min: F, max: F) -> PertBuilder<F> {
let shape = F::from(4.0).unwrap();
PertBuilder { min, max, shape }
}
}

/// Struct used to build a [`Pert`]
#[derive(Debug)]
pub struct PertBuilder<F> {
min: F,
max: F,
shape: F,
}

/// Set up the PERT distribution with defined `min`, `max`, `mode` and
/// `shape`.
pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result<Pert<F>, PertError> {
if !(max > min) {
impl<F> PertBuilder<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Set the shape parameter
///
/// If not specified, this defaults to 4.
#[inline]
pub fn with_shape(mut self, shape: F) -> PertBuilder<F> {
self.shape = shape;
self
}

/// Specify the mean
#[inline]
pub fn with_mean(self, mean: F) -> Result<Pert<F>, PertError> {
let two = F::from(2.0).unwrap();
let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape;
self.with_mode(mode)
}

/// Specify the mode
#[inline]
pub fn with_mode(self, mode: F) -> Result<Pert<F>, PertError> {
if !(self.max > self.min) {
return Err(PertError::RangeTooSmall);
}
if !(mode >= min && max >= mode) {
if !(mode >= self.min && self.max >= mode) {
return Err(PertError::ModeRange);
}
if !(shape >= F::from(0.).unwrap()) {
if !(self.shape >= F::from(0.).unwrap()) {
return Err(PertError::ShapeTooSmall);
}

let (min, max, shape) = (self.min, self.max, self.shape);
let range = max - min;
let v = F::from(1.0).unwrap() + shape * (mode - min) / range;
let w = F::from(1.0).unwrap() + shape * (max - mode) / range;
Expand Down Expand Up @@ -124,34 +168,38 @@ mod test {
#[test]
fn test_pert() {
for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
let _distr = Pert::new(min, max, mode).unwrap();
let _distr = Pert::new(min, max).with_mode(mode).unwrap();
// TODO: test correctness
}

for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
assert!(Pert::new(min, max, mode).is_err());
assert!(Pert::new(min, max).with_mode(mode).is_err());
}
}

#[test]
fn distributions_can_be_compared() {
assert_eq!(Pert::new(1.0, 3.0, 2.0), Pert::new(1.0, 3.0, 2.0));
let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0);
let p1 = Pert::new(min, max).with_mode(mode).unwrap();
let mean = (min + shape * mode + max) / (shape + 2.0);
let p2 = Pert::new(min, max).with_mean(mean).unwrap();
assert_eq!(p1, p2);
}

#[test]
fn mode_almost_half_range() {
assert!(Pert::new(0.0f32, 0.48258883, 0.24129441).is_ok());
assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok());
}

#[test]
fn almost_symmetric_about_zero() {
let distr = Pert::new(-10f32, 10f32, f32::EPSILON);
let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON);
assert!(distr.is_ok());
}

#[test]
fn almost_symmetric() {
let distr = Pert::new(0f32, 2f32, 1f32 + f32::EPSILON);
let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON);
assert!(distr.is_ok());
}
}
2 changes: 1 addition & 1 deletion rand_distr/tests/value_stability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ fn pert_stability() {
// mean = 4, var = 12/7
test_samples(
860,
Pert::new(2., 10., 3.).unwrap(),
Pert::new(2., 10.).with_mode(3.).unwrap(),
&[
4.908681667460367,
4.014196196158352,
Expand Down

0 comments on commit b51268e

Please sign in to comment.