diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index 526ea0bc124..d6e4931ce34 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -20,4 +20,3 @@ appveyor = { repository = "rust-random/rand" } [dependencies] rand = { path = "..", version = ">=0.5, <=0.7" } -num-traits = "0.2" diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs index 5ad36ae5f4b..a59ef842d3d 100644 --- a/rand_distr/src/cauchy.rs +++ b/rand_distr/src/cauchy.rs @@ -10,8 +10,8 @@ //! The Cauchy distribution. use rand::Rng; -use crate::Distribution; -use std::f64::consts::PI; +use crate::{Distribution, Standard}; +use crate::utils::Float; /// The Cauchy distribution `Cauchy(median, scale)`. /// @@ -28,9 +28,9 @@ use std::f64::consts::PI; /// println!("{} is from a Cauchy(2, 5) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Cauchy { - median: f64, - scale: f64 +pub struct Cauchy { + median: N, + scale: N, } /// Error type returned from `Cauchy::new`. @@ -40,11 +40,13 @@ pub enum Error { ScaleTooSmall, } -impl Cauchy { +impl Cauchy +where Standard: Distribution +{ /// Construct a new `Cauchy` with the given shape parameters /// `median` the peak location and `scale` the scale factor. - pub fn new(median: f64, scale: f64) -> Result { - if !(scale > 0.0) { + pub fn new(median: N, scale: N) -> Result, Error> { + if !(scale > N::from(0.0)) { return Err(Error::ScaleTooSmall); } Ok(Cauchy { @@ -54,13 +56,15 @@ impl Cauchy { } } -impl Distribution for Cauchy { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for Cauchy +where Standard: Distribution +{ + fn sample(&self, rng: &mut R) -> N { // sample from [0, 1) - let x = rng.gen::(); + let x = Standard.sample(rng); // get standard cauchy random number // note that π/2 is not exactly representable, even if x=0.5 the result is finite - let comp_dev = (PI * x).tan(); + let comp_dev = (N::pi() * x).tan(); // shift and scale according to parameters let result = self.median + self.scale * comp_dev; result @@ -99,7 +103,7 @@ mod test { fn test_cauchy_mean() { let cauchy = Cauchy::new(10.0, 5.0).unwrap(); let mut rng = crate::test::rng(123); - let mut sum = 0.0; + let mut sum = 0.0f64; for _ in 0..1000 { sum += cauchy.sample(&mut rng); } diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index bb53cd7c79b..b4ef530c075 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -10,8 +10,8 @@ //! The dirichlet distribution. use rand::Rng; -use crate::Distribution; -use crate::gamma::Gamma; +use crate::{Distribution, Gamma, StandardNormal, Exp1, Open01}; +use crate::utils::Float; /// The dirichelet distribution `Dirichlet(alpha)`. /// @@ -30,9 +30,9 @@ use crate::gamma::Gamma; /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); /// ``` #[derive(Clone, Debug)] -pub struct Dirichlet { +pub struct Dirichlet { /// Concentration parameters (alpha) - alpha: Vec, + alpha: Vec, } /// Error type returned from `Dirchlet::new`. @@ -46,18 +46,20 @@ pub enum Error { SizeTooSmall, } -impl Dirichlet { +impl Dirichlet +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. /// /// Requires `alpha.len() >= 2`. #[inline] - pub fn new>>(alpha: V) -> Result { + pub fn new>>(alpha: V) -> Result, Error> { let a = alpha.into(); if a.len() < 2 { return Err(Error::AlphaTooShort); } for i in 0..a.len() { - if !(a[i] > 0.0) { + if !(a[i] > N::from(0.0)) { return Err(Error::AlphaTooSmall); } } @@ -69,8 +71,8 @@ impl Dirichlet { /// /// Requires `size >= 2`. #[inline] - pub fn new_with_size(alpha: f64, size: usize) -> Result { - if !(alpha > 0.0) { + pub fn new_with_size(alpha: N, size: usize) -> Result, Error> { + if !(alpha > N::from(0.0)) { return Err(Error::AlphaTooSmall); } if size < 2 { @@ -82,18 +84,20 @@ impl Dirichlet { } } -impl Distribution> for Dirichlet { - fn sample(&self, rng: &mut R) -> Vec { +impl Distribution> for Dirichlet +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> Vec { let n = self.alpha.len(); - let mut samples = vec![0.0f64; n]; - let mut sum = 0.0f64; + let mut samples = vec![N::from(0.0); n]; + let mut sum = N::from(0.0); for i in 0..n { - let g = Gamma::new(self.alpha[i], 1.0).unwrap(); + let g = Gamma::new(self.alpha[i], N::from(1.0)).unwrap(); samples[i] = g.sample(rng); sum += samples[i]; } - let invacc = 1.0 / sum; + let invacc = N::from(1.0) / sum; for i in 0..n { samples[i] *= invacc; } diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index bd832f8279d..a4f76194602 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -11,8 +11,7 @@ use rand::Rng; use crate::{ziggurat_tables, Distribution}; -use crate::utils::ziggurat; -use num_traits::Float; +use crate::utils::{ziggurat, Float}; /// Samples floating-point numbers according to the exponential distribution, /// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or @@ -105,10 +104,10 @@ where Exp1: Distribution /// `lambda`. #[inline] pub fn new(lambda: N) -> Result, Error> { - if !(lambda > N::zero()) { + if !(lambda > N::from(0.0)) { return Err(Error::LambdaTooSmall); } - Ok(Exp { lambda_inverse: N::one() / lambda }) + Ok(Exp { lambda_inverse: N::from(1.0) / lambda }) } } diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 6035f61107d..4018361648e 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -15,7 +15,7 @@ use self::ChiSquaredRepr::*; use rand::Rng; use crate::normal::StandardNormal; use crate::{Distribution, Exp1, Exp, Open01}; -use num_traits::Float; +use crate::utils::Float; /// The Gamma distribution `Gamma(shape, scale)` distribution. /// @@ -108,16 +108,16 @@ where StandardNormal: Distribution, Exp1: Distribution, Open01: Distributi /// distribution. #[inline] pub fn new(shape: N, scale: N) -> Result, Error> { - if !(shape > N::zero()) { + if !(shape > N::from(0.0)) { return Err(Error::ShapeTooSmall); } - if !(scale > N::zero()) { + if !(scale > N::from(0.0)) { return Err(Error::ScaleTooSmall); } - let repr = if shape == N::one() { - One(Exp::new(N::one() / scale).map_err(|_| Error::ScaleTooLarge)?) - } else if shape < N::one() { + let repr = if shape == N::from(1.0) { + One(Exp::new(N::from(1.0) / scale).map_err(|_| Error::ScaleTooLarge)?) + } else if shape < N::from(1.0) { Small(GammaSmallShape::new_raw(shape, scale)) } else { Large(GammaLargeShape::new_raw(shape, scale)) @@ -131,8 +131,8 @@ where StandardNormal: Distribution, Open01: Distribution { fn new_raw(shape: N, scale: N) -> GammaSmallShape { GammaSmallShape { - inv_shape: N::one() / shape, - large_shape: GammaLargeShape::new_raw(shape + N::one(), scale) + inv_shape: N::from(1.0) / shape, + large_shape: GammaLargeShape::new_raw(shape + N::from(1.0), scale) } } } @@ -141,10 +141,10 @@ impl GammaLargeShape where StandardNormal: Distribution, Open01: Distribution { fn new_raw(shape: N, scale: N) -> GammaLargeShape { - let d = shape - N::from(1. / 3.).unwrap(); + let d = shape - N::from(1. / 3.); GammaLargeShape { scale, - c: N::one() / (N::from(9.).unwrap() * d).sqrt(), + c: N::from(1.0) / (N::from(9.) * d).sqrt(), d } } @@ -177,8 +177,8 @@ where StandardNormal: Distribution, Open01: Distribution // Marsaglia & Tsang method, 2000 loop { let x: N = rng.sample(StandardNormal); - let v_cbrt = N::one() + self.c * x; - if v_cbrt <= N::zero() { // a^3 <= 0 iff a <= 0 + let v_cbrt = N::from(1.0) + self.c * x; + if v_cbrt <= N::from(0.0) { // a^3 <= 0 iff a <= 0 continue } @@ -186,8 +186,8 @@ where StandardNormal: Distribution, Open01: Distribution let u: N = rng.sample(Open01); let x_sqr = x * x; - if u < N::one() - N::from(0.0331).unwrap() * x_sqr * x_sqr || - u.ln() < N::from(0.5).unwrap() * x_sqr + self.d * (N::one() - v + v.ln()) + if u < N::from(1.0) - N::from(0.0331) * x_sqr * x_sqr || + u.ln() < N::from(0.5) * x_sqr + self.d * (N::from(1.0) - v + v.ln()) { return self.d * v * self.scale } @@ -213,8 +213,8 @@ where StandardNormal: Distribution, Open01: Distribution /// println!("{} is from a χ²(11) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct ChiSquared { - repr: ChiSquaredRepr, +pub struct ChiSquared { + repr: ChiSquaredRepr, } /// Error type returned from `ChiSquared::new` and `StudentT::new`. @@ -225,35 +225,39 @@ pub enum ChiSquaredError { } #[derive(Clone, Copy, Debug)] -enum ChiSquaredRepr { +enum ChiSquaredRepr { // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, // e.g. when alpha = 1/2 as it would be for this case, so special- // casing and using the definition of N(0,1)^2 is faster. DoFExactlyOne, - DoFAnythingElse(Gamma), + DoFAnythingElse(Gamma), } -impl ChiSquared { +impl ChiSquared +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Create a new chi-squared distribution with degrees-of-freedom /// `k`. - pub fn new(k: f64) -> Result { - let repr = if k == 1.0 { + pub fn new(k: N) -> Result, ChiSquaredError> { + let repr = if k == N::from(1.0) { DoFExactlyOne } else { - if !(0.5 * k > 0.0) { + if !(N::from(0.5) * k > N::from(0.0)) { return Err(ChiSquaredError::DoFTooSmall); } - DoFAnythingElse(Gamma::new(0.5 * k, 2.0).unwrap()) + DoFAnythingElse(Gamma::new(N::from(0.5) * k, N::from(2.0)).unwrap()) }; Ok(ChiSquared { repr }) } } -impl Distribution for ChiSquared { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for ChiSquared +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { match self.repr { DoFExactlyOne => { // k == 1 => N(0,1)^2 - let norm: f64 = rng.sample(StandardNormal); + let norm: N = rng.sample(StandardNormal); norm * norm } DoFAnythingElse(ref g) => g.sample(rng) @@ -277,12 +281,12 @@ impl Distribution for ChiSquared { /// println!("{} is from an F(2, 32) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct FisherF { - numer: ChiSquared, - denom: ChiSquared, +pub struct FisherF { + numer: ChiSquared, + denom: ChiSquared, // denom_dof / numer_dof so that this can just be a straight // multiplication, rather than a division. - dof_ratio: f64, + dof_ratio: N, } /// Error type returned from `FisherF::new`. @@ -294,13 +298,15 @@ pub enum FisherFError { NTooSmall, } -impl FisherF { +impl FisherF +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Create a new `FisherF` distribution, with the given parameter. - pub fn new(m: f64, n: f64) -> Result { - if !(m > 0.0) { + pub fn new(m: N, n: N) -> Result, FisherFError> { + if !(m > N::from(0.0)) { return Err(FisherFError::MTooSmall); } - if !(n > 0.0) { + if !(n > N::from(0.0)) { return Err(FisherFError::NTooSmall); } @@ -311,8 +317,10 @@ impl FisherF { }) } } -impl Distribution for FisherF { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for FisherF +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio } } @@ -330,24 +338,28 @@ impl Distribution for FisherF { /// println!("{} is from a t(11) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct StudentT { - chi: ChiSquared, - dof: f64 +pub struct StudentT { + chi: ChiSquared, + dof: N } -impl StudentT { +impl StudentT +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Create a new Student t distribution with `n` degrees of /// freedom. - pub fn new(n: f64) -> Result { + pub fn new(n: N) -> Result, ChiSquaredError> { Ok(StudentT { chi: ChiSquared::new(n)?, dof: n }) } } -impl Distribution for StudentT { - fn sample(&self, rng: &mut R) -> f64 { - let norm: f64 = rng.sample(StandardNormal); +impl Distribution for StudentT +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + let norm: N = rng.sample(StandardNormal); norm * (self.dof / self.chi.sample(rng)).sqrt() } } @@ -364,9 +376,9 @@ impl Distribution for StudentT { /// println!("{} is from a Beta(2, 5) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Beta { - gamma_a: Gamma, - gamma_b: Gamma, +pub struct Beta { + gamma_a: Gamma, + gamma_b: Gamma, } /// Error type returned from `Beta::new`. @@ -378,21 +390,25 @@ pub enum BetaError { BetaTooSmall, } -impl Beta { +impl Beta +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Construct an object representing the `Beta(alpha, beta)` /// distribution. - pub fn new(alpha: f64, beta: f64) -> Result { + pub fn new(alpha: N, beta: N) -> Result, BetaError> { Ok(Beta { - gamma_a: Gamma::new(alpha, 1.) + gamma_a: Gamma::new(alpha, N::from(1.)) .map_err(|_| BetaError::AlphaTooSmall)?, - gamma_b: Gamma::new(beta, 1.) + gamma_b: Gamma::new(beta, N::from(1.)) .map_err(|_| BetaError::BetaTooSmall)?, }) } } -impl Distribution for Beta { - fn sample(&self, rng: &mut R) -> f64 { +impl Distribution for Beta +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { let x = self.gamma_a.sample(rng); let y = self.gamma_b.sample(rng); x / (x + y) diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index d172854098e..0e333e7c9d6 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -59,7 +59,7 @@ //! - [`UnitCircle`] distribution pub use rand::distributions::{Distribution, DistIter, Standard, - Alphanumeric, Uniform, OpenClosed01, Open01, Bernoulli, weighted}; + Alphanumeric, Uniform, OpenClosed01, Open01, Bernoulli, uniform, weighted}; pub use self::unit_sphere::UnitSphereSurface; pub use self::unit_circle::UnitCircle; @@ -75,6 +75,7 @@ pub use self::cauchy::{Cauchy, Error as CauchyError}; pub use self::dirichlet::{Dirichlet, Error as DirichletError}; pub use self::triangular::{Triangular, TriangularError}; pub use self::weibull::{Weibull, Error as WeibullError}; +pub use self::utils::Float; mod unit_sphere; mod unit_circle; diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index ea700e54d8a..036edc20569 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -11,8 +11,7 @@ use rand::Rng; use crate::{ziggurat_tables, Distribution, Open01}; -use crate::utils::ziggurat; -use num_traits::Float; +use crate::utils::{ziggurat, Float}; /// Samples floating-point numbers according to the normal distribution /// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to @@ -122,7 +121,7 @@ where StandardNormal: Distribution /// standard deviation. #[inline] pub fn new(mean: N, std_dev: N) -> Result, Error> { - if !(std_dev >= N::zero()) { + if !(std_dev >= N::from(0.0)) { return Err(Error::StdDevTooSmall); } Ok(Normal { @@ -169,7 +168,7 @@ where StandardNormal: Distribution /// and standard deviation of the logarithm of the distribution. #[inline] pub fn new(mean: N, std_dev: N) -> Result, Error> { - if !(std_dev >= N::zero()) { + if !(std_dev >= N::from(0.0)) { return Err(Error::StdDevTooSmall); } Ok(LogNormal { norm: Normal::new(mean, std_dev).unwrap() }) diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs index 59977940915..134aaaa5e05 100644 --- a/rand_distr/src/pareto.rs +++ b/rand_distr/src/pareto.rs @@ -10,6 +10,7 @@ use rand::Rng; use crate::{Distribution, OpenClosed01}; +use crate::utils::Float; /// Samples floating-point numbers according to the Pareto distribution /// @@ -22,9 +23,9 @@ use crate::{Distribution, OpenClosed01}; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Pareto { - scale: f64, - inv_neg_shape: f64, +pub struct Pareto { + scale: N, + inv_neg_shape: N, } /// Error type returned from `Pareto::new`. @@ -36,25 +37,29 @@ pub enum Error { ShapeTooSmall, } -impl Pareto { +impl Pareto +where OpenClosed01: Distribution +{ /// Construct a new Pareto distribution with given `scale` and `shape`. /// /// In the literature, `scale` is commonly written as xm or k and /// `shape` is often written as α. - pub fn new(scale: f64, shape: f64) -> Result { - if !(scale > 0.0) { + pub fn new(scale: N, shape: N) -> Result, Error> { + if !(scale > N::from(0.0)) { return Err(Error::ScaleTooSmall); } - if !(shape > 0.0) { + if !(shape > N::from(0.0)) { return Err(Error::ShapeTooSmall); } - Ok(Pareto { scale, inv_neg_shape: -1.0 / shape }) + Ok(Pareto { scale, inv_neg_shape: N::from(-1.0) / shape }) } } -impl Distribution for Pareto { - fn sample(&self, rng: &mut R) -> f64 { - let u: f64 = rng.sample(OpenClosed01); +impl Distribution for Pareto +where OpenClosed01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + let u: N = OpenClosed01.sample(rng); self.scale * u.powf(self.inv_neg_shape) } } diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs index 13d7faceca3..25d4c3a824c 100644 --- a/rand_distr/src/pert.rs +++ b/rand_distr/src/pert.rs @@ -8,7 +8,8 @@ //! The PERT distribution. use rand::Rng; -use crate::{Distribution, Beta}; +use crate::{Distribution, Beta, StandardNormal, Exp1, Open01}; +use crate::utils::Float; /// The PERT distribution. /// @@ -27,10 +28,10 @@ use crate::{Distribution, Beta}; /// println!("{} is from a PERT distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Pert { - min: f64, - range: f64, - beta: Beta, +pub struct Pert { + min: N, + range: N, + beta: Beta, } /// Error type returned from [`Pert`] constructors. @@ -44,34 +45,36 @@ pub enum PertError { ShapeTooSmall, } -impl Pert { +impl Pert +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ /// Set up the PERT distribution with defined `min`, `max` and `mode`. /// /// This is equivalent to calling `Pert::new_shape` with `shape == 4.0`. #[inline] - pub fn new(min: f64, max: f64, mode: f64) -> Result { - Pert::new_with_shape(min, max, mode, 4.) + pub fn new(min: N, max: N, mode: N) -> Result, PertError> { + Pert::new_with_shape(min, max, mode, N::from(4.)) } /// Set up the PERT distribution with defined `min`, `max`, `mode` and /// `shape`. - pub fn new_with_shape(min: f64, max: f64, mode: f64, shape: f64) -> Result { + pub fn new_with_shape(min: N, max: N, mode: N, shape: N) -> Result, PertError> { if !(max > min) { return Err(PertError::RangeTooSmall); } if !(mode >= min && max >= mode) { return Err(PertError::ModeRange); } - if !(shape >= 0.) { + if !(shape >= N::from(0.)) { return Err(PertError::ShapeTooSmall); } let range = max - min; - let mu = (min + max + shape * mode) / (shape + 2.); + let mu = (min + max + shape * mode) / (shape + N::from(2.)); let v = if mu == mode { - shape * 0.5 + 1. + shape * N::from(0.5) + N::from(1.) } else { - (mu - min) * (2. * mode - min - max) + (mu - min) * (N::from(2.) * mode - min - max) / ((mode - mu) * (max - min)) }; let w = v * (max - mu) / (mu - min); @@ -80,9 +83,11 @@ impl Pert { } } -impl Distribution for Pert { +impl Distribution for Pert +where StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution +{ #[inline] - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> N { self.beta.sample(rng) * self.range + self.min } } diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs index b735743b3ff..72d58c5f278 100644 --- a/rand_distr/src/triangular.rs +++ b/rand_distr/src/triangular.rs @@ -9,6 +9,7 @@ use rand::Rng; use crate::{Distribution, Standard}; +use crate::utils::Float; /// The triangular distribution. /// @@ -28,10 +29,10 @@ use crate::{Distribution, Standard}; /// println!("{} is from a triangular distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Triangular { - min: f64, - max: f64, - mode: f64, +pub struct Triangular { + min: N, + max: N, + mode: N, } /// Error type returned from [`Triangular::new`]. @@ -43,10 +44,12 @@ pub enum TriangularError { ModeRange, } -impl Triangular { +impl Triangular +where Standard: Distribution +{ /// Set up the Triangular distribution with defined `min`, `max` and `mode`. #[inline] - pub fn new(min: f64, max: f64, mode: f64) -> Result { + pub fn new(min: N, max: N, mode: N) -> Result, TriangularError> { if !(max >= min) { return Err(TriangularError::RangeTooSmall); } @@ -57,10 +60,12 @@ impl Triangular { } } -impl Distribution for Triangular { +impl Distribution for Triangular +where Standard: Distribution +{ #[inline] - fn sample(&self, rng: &mut R) -> f64 { - let f: f64 = rng.sample(Standard); + fn sample(&self, rng: &mut R) -> N { + let f: N = rng.sample(Standard); let diff_mode_min = self.mode - self.min; let range = self.max - self.min; let f_range = f * range; diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs index 948cbccf68b..55cfe4ff60c 100644 --- a/rand_distr/src/unit_circle.rs +++ b/rand_distr/src/unit_circle.rs @@ -7,7 +7,8 @@ // except according to those terms. use rand::Rng; -use crate::{Distribution, Uniform}; +use crate::{Distribution, Uniform, uniform::SampleUniform}; +use crate::utils::Float; /// Samples uniformly from the edge of the unit circle in two dimensions. /// @@ -19,7 +20,7 @@ use crate::{Distribution, Uniform}; /// ``` /// use rand_distr::{UnitCircle, Distribution}; /// -/// let v = UnitCircle.sample(&mut rand::thread_rng()); +/// let v: [f64; 2] = UnitCircle.sample(&mut rand::thread_rng()); /// println!("{:?} is from the unit circle.", v) /// ``` /// @@ -30,10 +31,10 @@ use crate::{Distribution, Uniform}; #[derive(Clone, Copy, Debug)] pub struct UnitCircle; -impl Distribution<[f64; 2]> for UnitCircle { +impl Distribution<[N; 2]> for UnitCircle { #[inline] - fn sample(&self, rng: &mut R) -> [f64; 2] { - let uniform = Uniform::new(-1., 1.); + fn sample(&self, rng: &mut R) -> [N; 2] { + let uniform = Uniform::new(N::from(-1.), N::from(1.)); let mut x1; let mut x2; let mut sum; @@ -41,12 +42,12 @@ impl Distribution<[f64; 2]> for UnitCircle { x1 = uniform.sample(rng); x2 = uniform.sample(rng); sum = x1*x1 + x2*x2; - if sum < 1. { + if sum < N::from(1.) { break; } } let diff = x1*x1 - x2*x2; - [diff / sum, 2.*x1*x2 / sum] + [diff / sum, N::from(2.)*x1*x2 / sum] } } @@ -75,7 +76,7 @@ mod tests { fn norm() { let mut rng = crate::test::rng(1); for _ in 0..1000 { - let x = UnitCircle.sample(&mut rng); + let x: [f64; 2] = UnitCircle.sample(&mut rng); assert_almost_eq!(x[0]*x[0] + x[1]*x[1], 1., 1e-15); } } @@ -83,8 +84,16 @@ mod tests { #[test] fn value_stability() { let mut rng = crate::test::rng(2); - assert_eq!(UnitCircle.sample(&mut rng), [-0.8032118336637037, 0.5956935036263119]); - assert_eq!(UnitCircle.sample(&mut rng), [-0.4742919588505423, -0.880367615130018]); - assert_eq!(UnitCircle.sample(&mut rng), [0.9297328981467168, 0.368234623716601]); + let expected = [ + [-0.8032118336637037, 0.5956935036263119], + [-0.4742919588505423, -0.880367615130018], + [0.9297328981467168, 0.368234623716601], + ]; + let samples: [[f64; 2]; 3] = [ + UnitCircle.sample(&mut rng), + UnitCircle.sample(&mut rng), + UnitCircle.sample(&mut rng), + ]; + assert_eq!(samples, expected); } } diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs index 1680e9da7e1..dad6f5f7e11 100644 --- a/rand_distr/src/unit_sphere.rs +++ b/rand_distr/src/unit_sphere.rs @@ -7,7 +7,8 @@ // except according to those terms. use rand::Rng; -use crate::{Distribution, Uniform}; +use crate::{Distribution, Uniform, uniform::SampleUniform}; +use crate::utils::Float; /// Samples uniformly from the surface of the unit sphere in three dimensions. /// @@ -19,7 +20,7 @@ use crate::{Distribution, Uniform}; /// ``` /// use rand_distr::{UnitSphereSurface, Distribution}; /// -/// let v = UnitSphereSurface.sample(&mut rand::thread_rng()); +/// let v: [f64; 3] = UnitSphereSurface.sample(&mut rand::thread_rng()); /// println!("{:?} is from the unit sphere surface.", v) /// ``` /// @@ -29,18 +30,18 @@ use crate::{Distribution, Uniform}; #[derive(Clone, Copy, Debug)] pub struct UnitSphereSurface; -impl Distribution<[f64; 3]> for UnitSphereSurface { +impl Distribution<[N; 3]> for UnitSphereSurface { #[inline] - fn sample(&self, rng: &mut R) -> [f64; 3] { - let uniform = Uniform::new(-1., 1.); + fn sample(&self, rng: &mut R) -> [N; 3] { + let uniform = Uniform::new(N::from(-1.), N::from(1.)); loop { let (x1, x2) = (uniform.sample(rng), uniform.sample(rng)); let sum = x1*x1 + x2*x2; - if sum >= 1. { + if sum >= N::from(1.) { continue; } - let factor = 2. * (1.0_f64 - sum).sqrt(); - return [x1 * factor, x2 * factor, 1. - 2.*sum]; + let factor = N::from(2.) * (N::from(1.0) - sum).sqrt(); + return [x1 * factor, x2 * factor, N::from(1.) - N::from(2.)*sum]; } } } @@ -70,7 +71,7 @@ mod tests { fn norm() { let mut rng = crate::test::rng(1); for _ in 0..1000 { - let x = UnitSphereSurface.sample(&mut rng); + let x: [f64; 3] = UnitSphereSurface.sample(&mut rng); assert_almost_eq!(x[0]*x[0] + x[1]*x[1] + x[2]*x[2], 1., 1e-15); } } @@ -78,11 +79,16 @@ mod tests { #[test] fn value_stability() { let mut rng = crate::test::rng(2); - assert_eq!(UnitSphereSurface.sample(&mut rng), - [-0.24950027180862533, -0.7552572587896719, 0.6060825747478084]); - assert_eq!(UnitSphereSurface.sample(&mut rng), - [0.47604534507233487, -0.797200864987207, -0.3712837328763685]); - assert_eq!(UnitSphereSurface.sample(&mut rng), - [0.9795722330927367, 0.18692349236651176, 0.07414747571708524]); + let expected = [ + [-0.24950027180862533, -0.7552572587896719, 0.6060825747478084], + [0.47604534507233487, -0.797200864987207, -0.3712837328763685], + [0.9795722330927367, 0.18692349236651176, 0.07414747571708524], + ]; + let samples: [[f64; 3]; 3] = [ + UnitSphereSurface.sample(&mut rng), + UnitSphereSurface.sample(&mut rng), + UnitSphereSurface.sample(&mut rng), + ]; + assert_eq!(samples, expected); } } diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index fcd81a472bb..e5f107f23e5 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -10,8 +10,73 @@ use rand::Rng; use crate::ziggurat_tables; - use rand::distributions::hidden_export::IntoFloat; +use core::{cmp, ops}; + +/// Trait for floating-point scalar types +/// +/// This allows many distributions to work with `f32` or `f64` parameters and is +/// potentially extensible. Note however that the `Exp1` and `StandardNormal` +/// distributions are implemented exclusively for `f32` and `f64`. +/// +/// The bounds and methods are based purely on internal +/// requirements, and will change as needed. +pub trait Float: Copy + Sized + cmp::PartialOrd + + ops::Neg + + ops::Add + + ops::Sub + + ops::Mul + + ops::Div + + ops::AddAssign + ops::SubAssign + ops::MulAssign + ops::DivAssign +{ + /// The constant π + fn pi() -> Self; + /// Support approximate representation of a f64 value + fn from(x: f64) -> Self; + + /// Take the absolute value of self + fn abs(self) -> Self; + + /// Take the exponential of self + fn exp(self) -> Self; + /// Take the natural logarithm of self + fn ln(self) -> Self; + /// Take square root of self + fn sqrt(self) -> Self; + /// Take self to a floating-point power + fn powf(self, power: Self) -> Self; + + /// Take the tangent of self + fn tan(self) -> Self; +} + +impl Float for f32 { + fn pi() -> Self { core::f32::consts::PI } + fn from(x: f64) -> Self { x as f32 } + + fn abs(self) -> Self { self.abs() } + + fn exp(self) -> Self { self.exp() } + fn ln(self) -> Self { self.ln() } + fn sqrt(self) -> Self { self.sqrt() } + fn powf(self, power: Self) -> Self { self.powf(power) } + + fn tan(self) -> Self { self.tan() } +} + +impl Float for f64 { + fn pi() -> Self { core::f64::consts::PI } + fn from(x: f64) -> Self { x } + + fn abs(self) -> Self { self.abs() } + + fn exp(self) -> Self { self.exp() } + fn ln(self) -> Self { self.ln() } + fn sqrt(self) -> Self { self.sqrt() } + fn powf(self, power: Self) -> Self { self.powf(power) } + + fn tan(self) -> Self { self.tan() } +} /// Calculates ln(gamma(x)) (natural logarithm of the gamma /// function) using the Lanczos approximation. @@ -26,7 +91,7 @@ use rand::distributions::hidden_export::IntoFloat; /// `Ag(z)` is an infinite series with coefficients that can be calculated /// ahead of time - we use just the first 6 terms, which is good enough /// for most purposes. -pub fn log_gamma(x: f64) -> f64 { +pub(crate) fn log_gamma(x: f64) -> f64 { // precalculated 6 coefficients for the first 6 terms of the series let coefficients: [f64; 6] = [ 76.18009172947146, @@ -71,7 +136,7 @@ pub fn log_gamma(x: f64) -> f64 { // the perf improvement (25-50%) is definitely worth the extra code // size from force-inlining. #[inline(always)] -pub fn ziggurat( +pub(crate) fn ziggurat( rng: &mut R, symmetric: bool, x_tab: ziggurat_tables::ZigTable, diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs index 8d3c66d8919..5a05da92dc5 100644 --- a/rand_distr/src/weibull.rs +++ b/rand_distr/src/weibull.rs @@ -10,6 +10,7 @@ use rand::Rng; use crate::{Distribution, OpenClosed01}; +use crate::utils::Float; /// Samples floating-point numbers according to the Weibull distribution /// @@ -22,9 +23,9 @@ use crate::{Distribution, OpenClosed01}; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Weibull { - inv_shape: f64, - scale: f64, +pub struct Weibull { + inv_shape: N, + scale: N, } /// Error type returned from `Weibull::new`. @@ -36,22 +37,26 @@ pub enum Error { ShapeTooSmall, } -impl Weibull { +impl Weibull +where OpenClosed01: Distribution +{ /// Construct a new `Weibull` distribution with given `scale` and `shape`. - pub fn new(scale: f64, shape: f64) -> Result { - if !(scale > 0.0) { + pub fn new(scale: N, shape: N) -> Result, Error> { + if !(scale > N::from(0.0)) { return Err(Error::ScaleTooSmall); } - if !(shape > 0.0) { + if !(shape > N::from(0.0)) { return Err(Error::ShapeTooSmall); } - Ok(Weibull { inv_shape: 1./shape, scale }) + Ok(Weibull { inv_shape: N::from(1.)/shape, scale }) } } -impl Distribution for Weibull { - fn sample(&self, rng: &mut R) -> f64 { - let x: f64 = rng.sample(OpenClosed01); +impl Distribution for Weibull +where OpenClosed01: Distribution +{ + fn sample(&self, rng: &mut R) -> N { + let x: N = rng.sample(OpenClosed01); self.scale * (-x.ln()).powf(self.inv_shape) } }