Skip to content

Commit

Permalink
Merge pull request #793 from dhardy/distr
Browse files Browse the repository at this point in the history
Generic distributions: use custom trait
  • Loading branch information
dhardy committed May 15, 2019
2 parents b664e64 + b330c21 commit 0f1b1ff
Show file tree
Hide file tree
Showing 14 changed files with 288 additions and 166 deletions.
1 change: 0 additions & 1 deletion rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,3 @@ appveyor = { repository = "rust-random/rand" }

[dependencies]
rand = { path = "..", version = ">=0.5, <=0.7" }
num-traits = "0.2"
30 changes: 17 additions & 13 deletions rand_distr/src/cauchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
//! The Cauchy distribution.

use rand::Rng;
use crate::Distribution;
use std::f64::consts::PI;
use crate::{Distribution, Standard};
use crate::utils::Float;

/// The Cauchy distribution `Cauchy(median, scale)`.
///
Expand All @@ -28,9 +28,9 @@ use std::f64::consts::PI;
/// println!("{} is from a Cauchy(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Cauchy {
median: f64,
scale: f64
pub struct Cauchy<N> {
median: N,
scale: N,
}

/// Error type returned from `Cauchy::new`.
Expand All @@ -40,11 +40,13 @@ pub enum Error {
ScaleTooSmall,
}

impl Cauchy {
impl<N: Float> Cauchy<N>
where Standard: Distribution<N>
{
/// Construct a new `Cauchy` with the given shape parameters
/// `median` the peak location and `scale` the scale factor.
pub fn new(median: f64, scale: f64) -> Result<Cauchy, Error> {
if !(scale > 0.0) {
pub fn new(median: N, scale: N) -> Result<Cauchy<N>, Error> {
if !(scale > N::from(0.0)) {
return Err(Error::ScaleTooSmall);
}
Ok(Cauchy {
Expand All @@ -54,13 +56,15 @@ impl Cauchy {
}
}

impl Distribution<f64> for Cauchy {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl<N: Float> Distribution<N> for Cauchy<N>
where Standard: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
// sample from [0, 1)
let x = rng.gen::<f64>();
let x = Standard.sample(rng);
// get standard cauchy random number
// note that π/2 is not exactly representable, even if x=0.5 the result is finite
let comp_dev = (PI * x).tan();
let comp_dev = (N::pi() * x).tan();
// shift and scale according to parameters
let result = self.median + self.scale * comp_dev;
result
Expand Down Expand Up @@ -99,7 +103,7 @@ mod test {
fn test_cauchy_mean() {
let cauchy = Cauchy::new(10.0, 5.0).unwrap();
let mut rng = crate::test::rng(123);
let mut sum = 0.0;
let mut sum = 0.0f64;
for _ in 0..1000 {
sum += cauchy.sample(&mut rng);
}
Expand Down
34 changes: 19 additions & 15 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
//! The dirichlet distribution.

use rand::Rng;
use crate::Distribution;
use crate::gamma::Gamma;
use crate::{Distribution, Gamma, StandardNormal, Exp1, Open01};
use crate::utils::Float;

/// The dirichelet distribution `Dirichlet(alpha)`.
///
Expand All @@ -30,9 +30,9 @@ use crate::gamma::Gamma;
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[derive(Clone, Debug)]
pub struct Dirichlet {
pub struct Dirichlet<N> {
/// Concentration parameters (alpha)
alpha: Vec<f64>,
alpha: Vec<N>,
}

/// Error type returned from `Dirchlet::new`.
Expand All @@ -46,18 +46,20 @@ pub enum Error {
SizeTooSmall,
}

impl Dirichlet {
impl<N: Float> Dirichlet<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
///
/// Requires `alpha.len() >= 2`.
#[inline]
pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Result<Dirichlet, Error> {
pub fn new<V: Into<Vec<N>>>(alpha: V) -> Result<Dirichlet<N>, Error> {
let a = alpha.into();
if a.len() < 2 {
return Err(Error::AlphaTooShort);
}
for i in 0..a.len() {
if !(a[i] > 0.0) {
if !(a[i] > N::from(0.0)) {
return Err(Error::AlphaTooSmall);
}
}
Expand All @@ -69,8 +71,8 @@ impl Dirichlet {
///
/// Requires `size >= 2`.
#[inline]
pub fn new_with_size(alpha: f64, size: usize) -> Result<Dirichlet, Error> {
if !(alpha > 0.0) {
pub fn new_with_size(alpha: N, size: usize) -> Result<Dirichlet<N>, Error> {
if !(alpha > N::from(0.0)) {
return Err(Error::AlphaTooSmall);
}
if size < 2 {
Expand All @@ -82,18 +84,20 @@ impl Dirichlet {
}
}

impl Distribution<Vec<f64>> for Dirichlet {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
impl<N: Float> Distribution<Vec<N>> for Dirichlet<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<N> {
let n = self.alpha.len();
let mut samples = vec![0.0f64; n];
let mut sum = 0.0f64;
let mut samples = vec![N::from(0.0); n];
let mut sum = N::from(0.0);

for i in 0..n {
let g = Gamma::new(self.alpha[i], 1.0).unwrap();
let g = Gamma::new(self.alpha[i], N::from(1.0)).unwrap();
samples[i] = g.sample(rng);
sum += samples[i];
}
let invacc = 1.0 / sum;
let invacc = N::from(1.0) / sum;
for i in 0..n {
samples[i] *= invacc;
}
Expand Down
7 changes: 3 additions & 4 deletions rand_distr/src/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

use rand::Rng;
use crate::{ziggurat_tables, Distribution};
use crate::utils::ziggurat;
use num_traits::Float;
use crate::utils::{ziggurat, Float};

/// Samples floating-point numbers according to the exponential distribution,
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
Expand Down Expand Up @@ -105,10 +104,10 @@ where Exp1: Distribution<N>
/// `lambda`.
#[inline]
pub fn new(lambda: N) -> Result<Exp<N>, Error> {
if !(lambda > N::zero()) {
if !(lambda > N::from(0.0)) {
return Err(Error::LambdaTooSmall);
}
Ok(Exp { lambda_inverse: N::one() / lambda })
Ok(Exp { lambda_inverse: N::from(1.0) / lambda })
}
}

Expand Down
Loading

0 comments on commit 0f1b1ff

Please sign in to comment.