From 4d129e4dfb29f2be4da572b929a2c0c839eea735 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Fri, 10 May 2019 17:48:50 +0100 Subject: [PATCH 1/9] rand_distr: remove dependence on num-traits A custom trait can be fine-tuned to our needs --- rand_distr/Cargo.toml | 1 - rand_distr/src/exponential.rs | 7 +++-- rand_distr/src/gamma.rs | 28 ++++++++++---------- rand_distr/src/lib.rs | 1 + rand_distr/src/normal.rs | 7 +++-- rand_distr/src/utils.rs | 48 ++++++++++++++++++++++++++++++++--- 6 files changed, 66 insertions(+), 26 deletions(-) diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index 526ea0bc124..d6e4931ce34 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -20,4 +20,3 @@ appveyor = { repository = "rust-random/rand" } [dependencies] rand = { path = "..", version = ">=0.5, <=0.7" } -num-traits = "0.2" diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index bd832f8279d..a4f76194602 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -11,8 +11,7 @@ use rand::Rng; use crate::{ziggurat_tables, Distribution}; -use crate::utils::ziggurat; -use num_traits::Float; +use crate::utils::{ziggurat, Float}; /// Samples floating-point numbers according to the exponential distribution, /// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or @@ -105,10 +104,10 @@ where Exp1: Distribution /// `lambda`. #[inline] pub fn new(lambda: N) -> Result, Error> { - if !(lambda > N::zero()) { + if !(lambda > N::from(0.0)) { return Err(Error::LambdaTooSmall); } - Ok(Exp { lambda_inverse: N::one() / lambda }) + Ok(Exp { lambda_inverse: N::from(1.0) / lambda }) } } diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 6035f61107d..2dcbebe09cb 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -15,7 +15,7 @@ use self::ChiSquaredRepr::*; use rand::Rng; use crate::normal::StandardNormal; use crate::{Distribution, Exp1, Exp, Open01}; -use num_traits::Float; +use crate::utils::Float; /// The Gamma distribution `Gamma(shape, scale)` distribution. /// @@ -108,16 +108,16 @@ where StandardNormal: Distribution, Exp1: Distribution, Open01: Distributi /// distribution. #[inline] pub fn new(shape: N, scale: N) -> Result, Error> { - if !(shape > N::zero()) { + if !(shape > N::from(0.0)) { return Err(Error::ShapeTooSmall); } - if !(scale > N::zero()) { + if !(scale > N::from(0.0)) { return Err(Error::ScaleTooSmall); } - let repr = if shape == N::one() { - One(Exp::new(N::one() / scale).map_err(|_| Error::ScaleTooLarge)?) - } else if shape < N::one() { + let repr = if shape == N::from(1.0) { + One(Exp::new(N::from(1.0) / scale).map_err(|_| Error::ScaleTooLarge)?) + } else if shape < N::from(1.0) { Small(GammaSmallShape::new_raw(shape, scale)) } else { Large(GammaLargeShape::new_raw(shape, scale)) @@ -131,8 +131,8 @@ where StandardNormal: Distribution, Open01: Distribution { fn new_raw(shape: N, scale: N) -> GammaSmallShape { GammaSmallShape { - inv_shape: N::one() / shape, - large_shape: GammaLargeShape::new_raw(shape + N::one(), scale) + inv_shape: N::from(1.0) / shape, + large_shape: GammaLargeShape::new_raw(shape + N::from(1.0), scale) } } } @@ -141,10 +141,10 @@ impl GammaLargeShape where StandardNormal: Distribution, Open01: Distribution { fn new_raw(shape: N, scale: N) -> GammaLargeShape { - let d = shape - N::from(1. / 3.).unwrap(); + let d = shape - N::from(1. / 3.); GammaLargeShape { scale, - c: N::one() / (N::from(9.).unwrap() * d).sqrt(), + c: N::from(1.0) / (N::from(9.) * d).sqrt(), d } } @@ -177,8 +177,8 @@ where StandardNormal: Distribution, Open01: Distribution // Marsaglia & Tsang method, 2000 loop { let x: N = rng.sample(StandardNormal); - let v_cbrt = N::one() + self.c * x; - if v_cbrt <= N::zero() { // a^3 <= 0 iff a <= 0 + let v_cbrt = N::from(1.0) + self.c * x; + if v_cbrt <= N::from(0.0) { // a^3 <= 0 iff a <= 0 continue } @@ -186,8 +186,8 @@ where StandardNormal: Distribution, Open01: Distribution let u: N = rng.sample(Open01); let x_sqr = x * x; - if u < N::one() - N::from(0.0331).unwrap() * x_sqr * x_sqr || - u.ln() < N::from(0.5).unwrap() * x_sqr + self.d * (N::one() - v + v.ln()) + if u < N::from(1.0) - N::from(0.0331) * x_sqr * x_sqr || + u.ln() < N::from(0.5) * x_sqr + self.d * (N::from(1.0) - v + v.ln()) { return self.d * v * self.scale } diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index d172854098e..32de237814c 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -75,6 +75,7 @@ pub use self::cauchy::{Cauchy, Error as CauchyError}; pub use self::dirichlet::{Dirichlet, Error as DirichletError}; pub use self::triangular::{Triangular, TriangularError}; pub use self::weibull::{Weibull, Error as WeibullError}; +pub use self::utils::Float; mod unit_sphere; mod unit_circle; diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index ea700e54d8a..036edc20569 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -11,8 +11,7 @@ use rand::Rng; use crate::{ziggurat_tables, Distribution, Open01}; -use crate::utils::ziggurat; -use num_traits::Float; +use crate::utils::{ziggurat, Float}; /// Samples floating-point numbers according to the normal distribution /// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to @@ -122,7 +121,7 @@ where StandardNormal: Distribution /// standard deviation. #[inline] pub fn new(mean: N, std_dev: N) -> Result, Error> { - if !(std_dev >= N::zero()) { + if !(std_dev >= N::from(0.0)) { return Err(Error::StdDevTooSmall); } Ok(Normal { @@ -169,7 +168,7 @@ where StandardNormal: Distribution /// and standard deviation of the logarithm of the distribution. #[inline] pub fn new(mean: N, std_dev: N) -> Result, Error> { - if !(std_dev >= N::zero()) { + if !(std_dev >= N::from(0.0)) { return Err(Error::StdDevTooSmall); } Ok(LogNormal { norm: Normal::new(mean, std_dev).unwrap() }) diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index fcd81a472bb..f7055944302 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -10,8 +10,50 @@ use rand::Rng; use crate::ziggurat_tables; - use rand::distributions::hidden_export::IntoFloat; +use core::{cmp, ops}; + +/// Trait for floating-point scalar types +/// +/// This allows many distributions to work with `f32` or `f64` parameters and is +/// potentially extensible. Note however that the `Exp1` and `StandardNormal` +/// distributions are implemented exclusively for `f32` and `f64`. +/// +/// The bounds and methods are based purely on internal +/// requirements, and will change as needed. +pub trait Float: Copy + Sized + cmp::PartialOrd + + ops::Add + + ops::Sub + + ops::Mul + + ops::Div +{ + /// Support approximate representation of a f64 value + fn from(x: f64) -> Self; + /// Take the exponential of self + fn exp(self) -> Self; + /// Take the natural logarithm of self + fn ln(self) -> Self; + /// Take square root of self + fn sqrt(self) -> Self; + /// Take self to a floating-point power + fn powf(self, power: Self) -> Self; +} + +impl Float for f32 { + fn from(x: f64) -> Self { x as f32 } + fn exp(self) -> Self { self.exp() } + fn ln(self) -> Self { self.ln() } + fn sqrt(self) -> Self { self.sqrt() } + fn powf(self, power: Self) -> Self { self.powf(power) } +} + +impl Float for f64 { + fn from(x: f64) -> Self { x } + fn exp(self) -> Self { self.exp() } + fn ln(self) -> Self { self.ln() } + fn sqrt(self) -> Self { self.sqrt() } + fn powf(self, power: Self) -> Self { self.powf(power) } +} /// Calculates ln(gamma(x)) (natural logarithm of the gamma /// function) using the Lanczos approximation. @@ -26,7 +68,7 @@ use rand::distributions::hidden_export::IntoFloat; /// `Ag(z)` is an infinite series with coefficients that can be calculated /// ahead of time - we use just the first 6 terms, which is good enough /// for most purposes. -pub fn log_gamma(x: f64) -> f64 { +pub(crate) fn log_gamma(x: f64) -> f64 { // precalculated 6 coefficients for the first 6 terms of the series let coefficients: [f64; 6] = [ 76.18009172947146, @@ -71,7 +113,7 @@ pub fn log_gamma(x: f64) -> f64 { // the perf improvement (25-50%) is definitely worth the extra code // size from force-inlining. #[inline(always)] -pub fn ziggurat( +pub(crate) fn ziggurat( rng: &mut R, symmetric: bool, x_tab: ziggurat_tables::ZigTable, From 49a26aa31427376fd3f809daf30fb9eea6c51fbe Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Fri, 10 May 2019 17:56:09 +0100 Subject: [PATCH 2/9] rand_distr: make Cauchy generic over Float type --- rand_distr/src/cauchy.rs | 30 +++++++++++++++++------------- rand_distr/src/utils.rs | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs index 5ad36ae5f4b..a59ef842d3d 100644 --- a/rand_distr/src/cauchy.rs +++ b/rand_distr/src/cauchy.rs @@ -10,8 +10,8 @@ //! The Cauchy distribution. use rand::Rng; -use crate::Distribution; -use std::f64::consts::PI; +use crate::{Distribution, Standard}; +use crate::utils::Float; /// The Cauchy distribution `Cauchy(median, scale)`. /// @@ -28,9 +28,9 @@ use std::f64::consts::PI; /// println!("{} is from a Cauchy(2, 5) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Cauchy { - median: f64, - scale: f64 +pub struct Cauchy { + median: N, + scale: N, } /// Error type returned from `Cauchy::new`. @@ -40,11 +40,13 @@ pub enum Error { ScaleTooSmall, } -impl Cauchy { +impl Cauchy +where Standard: Distribution +{ /// Construct a new `Cauchy` with the given shape parameters /// `median` the peak location and `scale` the scale factor. - pub fn new(median: f64, scale: f64) -> Result { - if !(scale > 0.0) { + pub fn new(median: N, scale: N) -> Result, Error> { + if !(scale > N::from(0.0)) { return Err(Error::ScaleTooSmall); } Ok(Cauchy { @@ -54,13 +56,15 @@ impl Cauchy { } } -impl Distribution for Cauchy { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for Cauchy +where Standard: Distribution +{ + fn sample(&self, rng: &mut R) -> N { // sample from [0, 1) - let x = rng.gen::(); + let x = Standard.sample(rng); // get standard cauchy random number // note that π/2 is not exactly representable, even if x=0.5 the result is finite - let comp_dev = (PI * x).tan(); + let comp_dev = (N::pi() * x).tan(); // shift and scale according to parameters let result = self.median + self.scale * comp_dev; result @@ -99,7 +103,7 @@ mod test { fn test_cauchy_mean() { let cauchy = Cauchy::new(10.0, 5.0).unwrap(); let mut rng = crate::test::rng(123); - let mut sum = 0.0; + let mut sum = 0.0f64; for _ in 0..1000 { sum += cauchy.sample(&mut rng); } diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index f7055944302..36714890a01 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -27,8 +27,14 @@ pub trait Float: Copy + Sized + cmp::PartialOrd + ops::Mul + ops::Div { + /// The constant π + fn pi() -> Self; /// Support approximate representation of a f64 value fn from(x: f64) -> Self; + + /// Take the absolute value of self + fn abs(self) -> Self; + /// Take the exponential of self fn exp(self) -> Self; /// Take the natural logarithm of self @@ -37,22 +43,37 @@ pub trait Float: Copy + Sized + cmp::PartialOrd fn sqrt(self) -> Self; /// Take self to a floating-point power fn powf(self, power: Self) -> Self; + + /// Take the tangent of self + fn tan(self) -> Self; } impl Float for f32 { + fn pi() -> Self { core::f32::consts::PI } fn from(x: f64) -> Self { x as f32 } + + fn abs(self) -> Self { self.abs() } + fn exp(self) -> Self { self.exp() } fn ln(self) -> Self { self.ln() } fn sqrt(self) -> Self { self.sqrt() } fn powf(self, power: Self) -> Self { self.powf(power) } + + fn tan(self) -> Self { self.tan() } } impl Float for f64 { + fn pi() -> Self { core::f64::consts::PI } fn from(x: f64) -> Self { x } + + fn abs(self) -> Self { self.abs() } + fn exp(self) -> Self { self.exp() } fn ln(self) -> Self { self.ln() } fn sqrt(self) -> Self { self.sqrt() } fn powf(self, power: Self) -> Self { self.powf(power) } + + fn tan(self) -> Self { self.tan() } } /// Calculates ln(gamma(x)) (natural logarithm of the gamma From 0290d42a494d433fd2a54bdf9bd1c75a90f24e53 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 11 May 2019 12:56:55 +0100 Subject: [PATCH 3/9] rand_distr: make Dirichlet distr generic over float types --- rand_distr/src/dirichlet.rs | 34 +++++++++++++++++++--------------- rand_distr/src/utils.rs | 1 + 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index bb53cd7c79b..b4ef530c075 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -10,8 +10,8 @@ //! The dirichlet distribution. use rand::Rng; -use crate::Distribution; -use crate::gamma::Gamma; +use crate::{Distribution, Gamma, StandardNormal, Exp1, Open01}; +use crate::utils::Float; /// The dirichelet distribution `Dirichlet(alpha)`. /// @@ -30,9 +30,9 @@ use crate::gamma::Gamma; /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); /// ``` #[derive(Clone, Debug)] -pub struct Dirichlet { +pub struct Dirichlet { /// Concentration parameters (alpha) - alpha: Vec, + alpha: Vec, } /// Error type returned from `Dirchlet::new`. @@ -46,18 +46,20 @@ pub enum Error { SizeTooSmall, } -impl Dirichlet { +impl Dirichlet +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. /// /// Requires `alpha.len() >= 2`. #[inline] - pub fn new>>(alpha: V) -> Result { + pub fn new>>(alpha: V) -> Result, Error> { let a = alpha.into(); if a.len() < 2 { return Err(Error::AlphaTooShort); } for i in 0..a.len() { - if !(a[i] > 0.0) { + if !(a[i] > N::from(0.0)) { return Err(Error::AlphaTooSmall); } } @@ -69,8 +71,8 @@ impl Dirichlet { /// /// Requires `size >= 2`. #[inline] - pub fn new_with_size(alpha: f64, size: usize) -> Result { - if !(alpha > 0.0) { + pub fn new_with_size(alpha: N, size: usize) -> Result, Error> { + if !(alpha > N::from(0.0)) { return Err(Error::AlphaTooSmall); } if size < 2 { @@ -82,18 +84,20 @@ impl Dirichlet { } } -impl Distribution> for Dirichlet { - fn sample(&self, rng: &mut R) -> Vec { +impl Distribution> for Dirichlet +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> Vec { let n = self.alpha.len(); - let mut samples = vec![0.0f64; n]; - let mut sum = 0.0f64; + let mut samples = vec![N::from(0.0); n]; + let mut sum = N::from(0.0); for i in 0..n { - let g = Gamma::new(self.alpha[i], 1.0).unwrap(); + let g = Gamma::new(self.alpha[i], N::from(1.0)).unwrap(); samples[i] = g.sample(rng); sum += samples[i]; } - let invacc = 1.0 / sum; + let invacc = N::from(1.0) / sum; for i in 0..n { samples[i] *= invacc; } diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index 36714890a01..5a1c774ac8d 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -26,6 +26,7 @@ pub trait Float: Copy + Sized + cmp::PartialOrd + ops::Sub + ops::Mul + ops::Div + + ops::AddAssign + ops::SubAssign + ops::MulAssign + ops::DivAssign { /// The constant π fn pi() -> Self; From 127ba16f8e070c683ae268b81e351d5827bdba97 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 11 May 2019 13:00:27 +0100 Subject: [PATCH 4/9] rand_distr: make Pareto generic over Float --- rand_distr/src/pareto.rs | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs index 59977940915..134aaaa5e05 100644 --- a/rand_distr/src/pareto.rs +++ b/rand_distr/src/pareto.rs @@ -10,6 +10,7 @@ use rand::Rng; use crate::{Distribution, OpenClosed01}; +use crate::utils::Float; /// Samples floating-point numbers according to the Pareto distribution /// @@ -22,9 +23,9 @@ use crate::{Distribution, OpenClosed01}; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Pareto { - scale: f64, - inv_neg_shape: f64, +pub struct Pareto { + scale: N, + inv_neg_shape: N, } /// Error type returned from `Pareto::new`. @@ -36,25 +37,29 @@ pub enum Error { ShapeTooSmall, } -impl Pareto { +impl Pareto +where OpenClosed01: Distribution +{ /// Construct a new Pareto distribution with given `scale` and `shape`. /// /// In the literature, `scale` is commonly written as xm or k and /// `shape` is often written as α. - pub fn new(scale: f64, shape: f64) -> Result { - if !(scale > 0.0) { + pub fn new(scale: N, shape: N) -> Result, Error> { + if !(scale > N::from(0.0)) { return Err(Error::ScaleTooSmall); } - if !(shape > 0.0) { + if !(shape > N::from(0.0)) { return Err(Error::ShapeTooSmall); } - Ok(Pareto { scale, inv_neg_shape: -1.0 / shape }) + Ok(Pareto { scale, inv_neg_shape: N::from(-1.0) / shape }) } } -impl Distribution for Pareto { - fn sample(&self, rng: &mut R) -> f64 { - let u: f64 = rng.sample(OpenClosed01); +impl Distribution for Pareto +where OpenClosed01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + let u: N = OpenClosed01.sample(rng); self.scale * u.powf(self.inv_neg_shape) } } From 5487157babe7f2b3956bb305c2e04070ca324d53 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 11 May 2019 14:38:51 +0100 Subject: [PATCH 5/9] rand_distr: make ChiSquared, FisherF, StudentT, Beta and Pert distributions generic over Float --- rand_distr/src/gamma.rs | 94 ++++++++++++++++++++++++----------------- rand_distr/src/pert.rs | 35 ++++++++------- 2 files changed, 75 insertions(+), 54 deletions(-) diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 2dcbebe09cb..4018361648e 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -213,8 +213,8 @@ where StandardNormal: Distribution, Open01: Distribution /// println!("{} is from a χ²(11) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct ChiSquared { - repr: ChiSquaredRepr, +pub struct ChiSquared { + repr: ChiSquaredRepr, } /// Error type returned from `ChiSquared::new` and `StudentT::new`. @@ -225,35 +225,39 @@ pub enum ChiSquaredError { } #[derive(Clone, Copy, Debug)] -enum ChiSquaredRepr { +enum ChiSquaredRepr { // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, // e.g. when alpha = 1/2 as it would be for this case, so special- // casing and using the definition of N(0,1)^2 is faster. DoFExactlyOne, - DoFAnythingElse(Gamma), + DoFAnythingElse(Gamma), } -impl ChiSquared { +impl ChiSquared +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Create a new chi-squared distribution with degrees-of-freedom /// `k`. - pub fn new(k: f64) -> Result { - let repr = if k == 1.0 { + pub fn new(k: N) -> Result, ChiSquaredError> { + let repr = if k == N::from(1.0) { DoFExactlyOne } else { - if !(0.5 * k > 0.0) { + if !(N::from(0.5) * k > N::from(0.0)) { return Err(ChiSquaredError::DoFTooSmall); } - DoFAnythingElse(Gamma::new(0.5 * k, 2.0).unwrap()) + DoFAnythingElse(Gamma::new(N::from(0.5) * k, N::from(2.0)).unwrap()) }; Ok(ChiSquared { repr }) } } -impl Distribution for ChiSquared { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for ChiSquared +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { match self.repr { DoFExactlyOne => { // k == 1 => N(0,1)^2 - let norm: f64 = rng.sample(StandardNormal); + let norm: N = rng.sample(StandardNormal); norm * norm } DoFAnythingElse(ref g) => g.sample(rng) @@ -277,12 +281,12 @@ impl Distribution for ChiSquared { /// println!("{} is from an F(2, 32) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct FisherF { - numer: ChiSquared, - denom: ChiSquared, +pub struct FisherF { + numer: ChiSquared, + denom: ChiSquared, // denom_dof / numer_dof so that this can just be a straight // multiplication, rather than a division. - dof_ratio: f64, + dof_ratio: N, } /// Error type returned from `FisherF::new`. @@ -294,13 +298,15 @@ pub enum FisherFError { NTooSmall, } -impl FisherF { +impl FisherF +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Create a new `FisherF` distribution, with the given parameter. - pub fn new(m: f64, n: f64) -> Result { - if !(m > 0.0) { + pub fn new(m: N, n: N) -> Result, FisherFError> { + if !(m > N::from(0.0)) { return Err(FisherFError::MTooSmall); } - if !(n > 0.0) { + if !(n > N::from(0.0)) { return Err(FisherFError::NTooSmall); } @@ -311,8 +317,10 @@ impl FisherF { }) } } -impl Distribution for FisherF { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for FisherF +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio } } @@ -330,24 +338,28 @@ impl Distribution for FisherF { /// println!("{} is from a t(11) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct StudentT { - chi: ChiSquared, - dof: f64 +pub struct StudentT { + chi: ChiSquared, + dof: N } -impl StudentT { +impl StudentT +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Create a new Student t distribution with `n` degrees of /// freedom. - pub fn new(n: f64) -> Result { + pub fn new(n: N) -> Result, ChiSquaredError> { Ok(StudentT { chi: ChiSquared::new(n)?, dof: n }) } } -impl Distribution for StudentT { - fn sample(&self, rng: &mut R) -> f64 { - let norm: f64 = rng.sample(StandardNormal); +impl Distribution for StudentT +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + let norm: N = rng.sample(StandardNormal); norm * (self.dof / self.chi.sample(rng)).sqrt() } } @@ -364,9 +376,9 @@ impl Distribution for StudentT { /// println!("{} is from a Beta(2, 5) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Beta { - gamma_a: Gamma, - gamma_b: Gamma, +pub struct Beta { + gamma_a: Gamma, + gamma_b: Gamma, } /// Error type returned from `Beta::new`. @@ -378,21 +390,25 @@ pub enum BetaError { BetaTooSmall, } -impl Beta { +impl Beta +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Construct an object representing the `Beta(alpha, beta)` /// distribution. - pub fn new(alpha: f64, beta: f64) -> Result { + pub fn new(alpha: N, beta: N) -> Result, BetaError> { Ok(Beta { - gamma_a: Gamma::new(alpha, 1.) + gamma_a: Gamma::new(alpha, N::from(1.)) .map_err(|_| BetaError::AlphaTooSmall)?, - gamma_b: Gamma::new(beta, 1.) + gamma_b: Gamma::new(beta, N::from(1.)) .map_err(|_| BetaError::BetaTooSmall)?, }) } } -impl Distribution for Beta { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for Beta +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { let x = self.gamma_a.sample(rng); let y = self.gamma_b.sample(rng); x / (x + y) diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs index 13d7faceca3..25d4c3a824c 100644 --- a/rand_distr/src/pert.rs +++ b/rand_distr/src/pert.rs @@ -8,7 +8,8 @@ //! The PERT distribution. use rand::Rng; -use crate::{Distribution, Beta}; +use crate::{Distribution, Beta, StandardNormal, Exp1, Open01}; +use crate::utils::Float; /// The PERT distribution. /// @@ -27,10 +28,10 @@ use crate::{Distribution, Beta}; /// println!("{} is from a PERT distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Pert { - min: f64, - range: f64, - beta: Beta, +pub struct Pert { + min: N, + range: N, + beta: Beta, } /// Error type returned from [`Pert`] constructors. @@ -44,34 +45,36 @@ pub enum PertError { ShapeTooSmall, } -impl Pert { +impl Pert +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Set up the PERT distribution with defined `min`, `max` and `mode`. /// /// This is equivalent to calling `Pert::new_shape` with `shape == 4.0`. #[inline] - pub fn new(min: f64, max: f64, mode: f64) -> Result { - Pert::new_with_shape(min, max, mode, 4.) + pub fn new(min: N, max: N, mode: N) -> Result, PertError> { + Pert::new_with_shape(min, max, mode, N::from(4.)) } /// Set up the PERT distribution with defined `min`, `max`, `mode` and /// `shape`. - pub fn new_with_shape(min: f64, max: f64, mode: f64, shape: f64) -> Result { + pub fn new_with_shape(min: N, max: N, mode: N, shape: N) -> Result, PertError> { if !(max > min) { return Err(PertError::RangeTooSmall); } if !(mode >= min && max >= mode) { return Err(PertError::ModeRange); } - if !(shape >= 0.) { + if !(shape >= N::from(0.)) { return Err(PertError::ShapeTooSmall); } let range = max - min; - let mu = (min + max + shape * mode) / (shape + 2.); + let mu = (min + max + shape * mode) / (shape + N::from(2.)); let v = if mu == mode { - shape * 0.5 + 1. + shape * N::from(0.5) + N::from(1.) } else { - (mu - min) * (2. * mode - min - max) + (mu - min) * (N::from(2.) * mode - min - max) / ((mode - mu) * (max - min)) }; let w = v * (max - mu) / (mu - min); @@ -80,9 +83,11 @@ impl Pert { } } -impl Distribution for Pert { +impl Distribution for Pert +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ #[inline] - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> N { self.beta.sample(rng) * self.range + self.min } } From 9f13fcf86f69ed1f175dd630e2a26392d7966486 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 13 May 2019 11:30:25 +0100 Subject: [PATCH 6/9] rand_distr: make Triangular generic over Float --- rand_distr/src/triangular.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs index b735743b3ff..72d58c5f278 100644 --- a/rand_distr/src/triangular.rs +++ b/rand_distr/src/triangular.rs @@ -9,6 +9,7 @@ use rand::Rng; use crate::{Distribution, Standard}; +use crate::utils::Float; /// The triangular distribution. /// @@ -28,10 +29,10 @@ use crate::{Distribution, Standard}; /// println!("{} is from a triangular distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Triangular { - min: f64, - max: f64, - mode: f64, +pub struct Triangular { + min: N, + max: N, + mode: N, } /// Error type returned from [`Triangular::new`]. @@ -43,10 +44,12 @@ pub enum TriangularError { ModeRange, } -impl Triangular { +impl Triangular +where Standard: Distribution +{ /// Set up the Triangular distribution with defined `min`, `max` and `mode`. #[inline] - pub fn new(min: f64, max: f64, mode: f64) -> Result { + pub fn new(min: N, max: N, mode: N) -> Result, TriangularError> { if !(max >= min) { return Err(TriangularError::RangeTooSmall); } @@ -57,10 +60,12 @@ impl Triangular { } } -impl Distribution for Triangular { +impl Distribution for Triangular +where Standard: Distribution +{ #[inline] - fn sample(&self, rng: &mut R) -> f64 { - let f: f64 = rng.sample(Standard); + fn sample(&self, rng: &mut R) -> N { + let f: N = rng.sample(Standard); let diff_mode_min = self.mode - self.min; let range = self.max - self.min; let f_range = f * range; From db795b5e97eaa764616c7602f6b0cd18d24027ab Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 13 May 2019 11:38:47 +0100 Subject: [PATCH 7/9] rand_distr: make UnitCircle and UnitSphereSurface generic over Float --- rand_distr/src/lib.rs | 2 +- rand_distr/src/unit_circle.rs | 13 +++++++------ rand_distr/src/unit_sphere.rs | 15 ++++++++------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 32de237814c..0e333e7c9d6 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -59,7 +59,7 @@ //! - [`UnitCircle`] distribution pub use rand::distributions::{Distribution, DistIter, Standard, - Alphanumeric, Uniform, OpenClosed01, Open01, Bernoulli, weighted}; + Alphanumeric, Uniform, OpenClosed01, Open01, Bernoulli, uniform, weighted}; pub use self::unit_sphere::UnitSphereSurface; pub use self::unit_circle::UnitCircle; diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs index 948cbccf68b..65ee234911b 100644 --- a/rand_distr/src/unit_circle.rs +++ b/rand_distr/src/unit_circle.rs @@ -7,7 +7,8 @@ // except according to those terms. use rand::Rng; -use crate::{Distribution, Uniform}; +use crate::{Distribution, Uniform, uniform::SampleUniform}; +use crate::utils::Float; /// Samples uniformly from the edge of the unit circle in two dimensions. /// @@ -30,10 +31,10 @@ use crate::{Distribution, Uniform}; #[derive(Clone, Copy, Debug)] pub struct UnitCircle; -impl Distribution<[f64; 2]> for UnitCircle { +impl Distribution<[N; 2]> for UnitCircle { #[inline] - fn sample(&self, rng: &mut R) -> [f64; 2] { - let uniform = Uniform::new(-1., 1.); + fn sample(&self, rng: &mut R) -> [N; 2] { + let uniform = Uniform::new(N::from(-1.), N::from(1.)); let mut x1; let mut x2; let mut sum; @@ -41,12 +42,12 @@ impl Distribution<[f64; 2]> for UnitCircle { x1 = uniform.sample(rng); x2 = uniform.sample(rng); sum = x1*x1 + x2*x2; - if sum < 1. { + if sum < N::from(1.) { break; } } let diff = x1*x1 - x2*x2; - [diff / sum, 2.*x1*x2 / sum] + [diff / sum, N::from(2.)*x1*x2 / sum] } } diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs index 1680e9da7e1..f7d453c092d 100644 --- a/rand_distr/src/unit_sphere.rs +++ b/rand_distr/src/unit_sphere.rs @@ -7,7 +7,8 @@ // except according to those terms. use rand::Rng; -use crate::{Distribution, Uniform}; +use crate::{Distribution, Uniform, uniform::SampleUniform}; +use crate::utils::Float; /// Samples uniformly from the surface of the unit sphere in three dimensions. /// @@ -29,18 +30,18 @@ use crate::{Distribution, Uniform}; #[derive(Clone, Copy, Debug)] pub struct UnitSphereSurface; -impl Distribution<[f64; 3]> for UnitSphereSurface { +impl Distribution<[N; 3]> for UnitSphereSurface { #[inline] - fn sample(&self, rng: &mut R) -> [f64; 3] { - let uniform = Uniform::new(-1., 1.); + fn sample(&self, rng: &mut R) -> [N; 3] { + let uniform = Uniform::new(N::from(-1.), N::from(1.)); loop { let (x1, x2) = (uniform.sample(rng), uniform.sample(rng)); let sum = x1*x1 + x2*x2; - if sum >= 1. { + if sum >= N::from(1.) { continue; } - let factor = 2. * (1.0_f64 - sum).sqrt(); - return [x1 * factor, x2 * factor, 1. - 2.*sum]; + let factor = N::from(2.) * (N::from(1.0) - sum).sqrt(); + return [x1 * factor, x2 * factor, N::from(1.) - N::from(2.)*sum]; } } } From ef5961357b265d1f26868980e9880ec0c57538ce Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 13 May 2019 11:41:33 +0100 Subject: [PATCH 8/9] rand_distr: make Weibull generic over Float --- rand_distr/src/utils.rs | 1 + rand_distr/src/weibull.rs | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index 5a1c774ac8d..e5f107f23e5 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -22,6 +22,7 @@ use core::{cmp, ops}; /// The bounds and methods are based purely on internal /// requirements, and will change as needed. pub trait Float: Copy + Sized + cmp::PartialOrd + + ops::Neg + ops::Add + ops::Sub + ops::Mul diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs index 8d3c66d8919..5a05da92dc5 100644 --- a/rand_distr/src/weibull.rs +++ b/rand_distr/src/weibull.rs @@ -10,6 +10,7 @@ use rand::Rng; use crate::{Distribution, OpenClosed01}; +use crate::utils::Float; /// Samples floating-point numbers according to the Weibull distribution /// @@ -22,9 +23,9 @@ use crate::{Distribution, OpenClosed01}; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Weibull { - inv_shape: f64, - scale: f64, +pub struct Weibull { + inv_shape: N, + scale: N, } /// Error type returned from `Weibull::new`. @@ -36,22 +37,26 @@ pub enum Error { ShapeTooSmall, } -impl Weibull { +impl Weibull +where OpenClosed01: Distribution +{ /// Construct a new `Weibull` distribution with given `scale` and `shape`. - pub fn new(scale: f64, shape: f64) -> Result { - if !(scale > 0.0) { + pub fn new(scale: N, shape: N) -> Result, Error> { + if !(scale > N::from(0.0)) { return Err(Error::ScaleTooSmall); } - if !(shape > 0.0) { + if !(shape > N::from(0.0)) { return Err(Error::ShapeTooSmall); } - Ok(Weibull { inv_shape: 1./shape, scale }) + Ok(Weibull { inv_shape: N::from(1.)/shape, scale }) } } -impl Distribution for Weibull { - fn sample(&self, rng: &mut R) -> f64 { - let x: f64 = rng.sample(OpenClosed01); +impl Distribution for Weibull +where OpenClosed01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + let x: N = rng.sample(OpenClosed01); self.scale * (-x.ln()).powf(self.inv_shape) } } From b330c210014eed72ac1b859e8ef8525efc6ca2b9 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 14 May 2019 12:54:26 +0100 Subject: [PATCH 9/9] rand_distr: fix unit circle and sphere type specifications These are required as a side-effect of the previous generalisation. --- rand_distr/src/unit_circle.rs | 18 +++++++++++++----- rand_distr/src/unit_sphere.rs | 21 +++++++++++++-------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs index 65ee234911b..55cfe4ff60c 100644 --- a/rand_distr/src/unit_circle.rs +++ b/rand_distr/src/unit_circle.rs @@ -20,7 +20,7 @@ use crate::utils::Float; /// ``` /// use rand_distr::{UnitCircle, Distribution}; /// -/// let v = UnitCircle.sample(&mut rand::thread_rng()); +/// let v: [f64; 2] = UnitCircle.sample(&mut rand::thread_rng()); /// println!("{:?} is from the unit circle.", v) /// ``` /// @@ -76,7 +76,7 @@ mod tests { fn norm() { let mut rng = crate::test::rng(1); for _ in 0..1000 { - let x = UnitCircle.sample(&mut rng); + let x: [f64; 2] = UnitCircle.sample(&mut rng); assert_almost_eq!(x[0]*x[0] + x[1]*x[1], 1., 1e-15); } } @@ -84,8 +84,16 @@ mod tests { #[test] fn value_stability() { let mut rng = crate::test::rng(2); - assert_eq!(UnitCircle.sample(&mut rng), [-0.8032118336637037, 0.5956935036263119]); - assert_eq!(UnitCircle.sample(&mut rng), [-0.4742919588505423, -0.880367615130018]); - assert_eq!(UnitCircle.sample(&mut rng), [0.9297328981467168, 0.368234623716601]); + let expected = [ + [-0.8032118336637037, 0.5956935036263119], + [-0.4742919588505423, -0.880367615130018], + [0.9297328981467168, 0.368234623716601], + ]; + let samples: [[f64; 2]; 3] = [ + UnitCircle.sample(&mut rng), + UnitCircle.sample(&mut rng), + UnitCircle.sample(&mut rng), + ]; + assert_eq!(samples, expected); } } diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs index f7d453c092d..dad6f5f7e11 100644 --- a/rand_distr/src/unit_sphere.rs +++ b/rand_distr/src/unit_sphere.rs @@ -20,7 +20,7 @@ use crate::utils::Float; /// ``` /// use rand_distr::{UnitSphereSurface, Distribution}; /// -/// let v = UnitSphereSurface.sample(&mut rand::thread_rng()); +/// let v: [f64; 3] = UnitSphereSurface.sample(&mut rand::thread_rng()); /// println!("{:?} is from the unit sphere surface.", v) /// ``` /// @@ -71,7 +71,7 @@ mod tests { fn norm() { let mut rng = crate::test::rng(1); for _ in 0..1000 { - let x = UnitSphereSurface.sample(&mut rng); + let x: [f64; 3] = UnitSphereSurface.sample(&mut rng); assert_almost_eq!(x[0]*x[0] + x[1]*x[1] + x[2]*x[2], 1., 1e-15); } } @@ -79,11 +79,16 @@ mod tests { #[test] fn value_stability() { let mut rng = crate::test::rng(2); - assert_eq!(UnitSphereSurface.sample(&mut rng), - [-0.24950027180862533, -0.7552572587896719, 0.6060825747478084]); - assert_eq!(UnitSphereSurface.sample(&mut rng), - [0.47604534507233487, -0.797200864987207, -0.3712837328763685]); - assert_eq!(UnitSphereSurface.sample(&mut rng), - [0.9795722330927367, 0.18692349236651176, 0.07414747571708524]); + let expected = [ + [-0.24950027180862533, -0.7552572587896719, 0.6060825747478084], + [0.47604534507233487, -0.797200864987207, -0.3712837328763685], + [0.9795722330927367, 0.18692349236651176, 0.07414747571708524], + ]; + let samples: [[f64; 3]; 3] = [ + UnitSphereSurface.sample(&mut rng), + UnitSphereSurface.sample(&mut rng), + UnitSphereSurface.sample(&mut rng), + ]; + assert_eq!(samples, expected); } }