Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make distributions generic / impl for f32 #785

Merged
merged 4 commits into from May 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions rand_distr/Cargo.toml
Expand Up @@ -20,3 +20,4 @@ appveyor = { repository = "rust-random/rand" }

[dependencies]
rand = { path = "..", version = ">=0.5, <=0.7" }
num-traits = "0.2"
33 changes: 23 additions & 10 deletions rand_distr/src/exponential.rs
Expand Up @@ -12,6 +12,7 @@
use rand::Rng;
use crate::{ziggurat_tables, Distribution};
use crate::utils::ziggurat;
use num_traits::Float;

/// Samples floating-point numbers according to the exponential distribution,
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
Expand Down Expand Up @@ -39,6 +40,15 @@ use crate::utils::ziggurat;
#[derive(Clone, Copy, Debug)]
pub struct Exp1;

impl Distribution<f32> for Exp1 {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
// TODO: use optimal 32-bit implementation
let x: f64 = self.sample(rng);
x as f32
}
}

// This could be done via `-rng.gen::<f64>().ln()` but that is slower.
impl Distribution<f64> for Exp1 {
#[inline]
Expand Down Expand Up @@ -76,9 +86,9 @@ impl Distribution<f64> for Exp1 {
/// println!("{} is from a Exp(2) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Exp {
pub struct Exp<N> {
/// `lambda` stored as `1/lambda`, since this is what we scale by.
lambda_inverse: f64
lambda_inverse: N
}

/// Error type returned from `Exp::new`.
Expand All @@ -88,22 +98,25 @@ pub enum Error {
LambdaTooSmall,
}

impl Exp {
impl<N: Float> Exp<N>
where Exp1: Distribution<N>
{
/// Construct a new `Exp` with the given shape parameter
/// `lambda`.
#[inline]
pub fn new(lambda: f64) -> Result<Exp, Error> {
if !(lambda > 0.0) {
pub fn new(lambda: N) -> Result<Exp<N>, Error> {
if !(lambda > N::zero()) {
return Err(Error::LambdaTooSmall);
}
Ok(Exp { lambda_inverse: 1.0 / lambda })
Ok(Exp { lambda_inverse: N::one() / lambda })
}
}

impl Distribution<f64> for Exp {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let n: f64 = rng.sample(Exp1);
n * self.lambda_inverse
impl<N: Float> Distribution<N> for Exp<N>
where Exp1: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
rng.sample(Exp1) * self.lambda_inverse
}
}

Expand Down
109 changes: 62 additions & 47 deletions rand_distr/src/gamma.rs
Expand Up @@ -14,7 +14,8 @@ use self::ChiSquaredRepr::*;

use rand::Rng;
use crate::normal::StandardNormal;
use crate::{Distribution, Exp, Open01};
use crate::{Distribution, Exp1, Exp, Open01};
use num_traits::Float;

/// The Gamma distribution `Gamma(shape, scale)` distribution.
///
Expand Down Expand Up @@ -47,8 +48,8 @@ use crate::{Distribution, Exp, Open01};
/// (September 2000), 363-372.
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
#[derive(Clone, Copy, Debug)]
pub struct Gamma {
repr: GammaRepr,
pub struct Gamma<N> {
repr: GammaRepr<N>,
}

/// Error type returned from `Gamma::new`.
Expand All @@ -63,10 +64,10 @@ pub enum Error {
}

#[derive(Clone, Copy, Debug)]
enum GammaRepr {
Large(GammaLargeShape),
One(Exp),
Small(GammaSmallShape)
enum GammaRepr<N> {
Large(GammaLargeShape<N>),
One(Exp<N>),
Small(GammaSmallShape<N>)
}

// These two helpers could be made public, but saving the
Expand All @@ -84,37 +85,39 @@ enum GammaRepr {
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
struct GammaSmallShape {
inv_shape: f64,
large_shape: GammaLargeShape
struct GammaSmallShape<N> {
inv_shape: N,
large_shape: GammaLargeShape<N>
}

/// Gamma distribution where the shape parameter is larger than 1.
///
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
struct GammaLargeShape {
scale: f64,
c: f64,
d: f64
struct GammaLargeShape<N> {
scale: N,
c: N,
d: N
}

impl Gamma {
impl<N: Float> Gamma<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
/// Construct an object representing the `Gamma(shape, scale)`
/// distribution.
#[inline]
pub fn new(shape: f64, scale: f64) -> Result<Gamma, Error> {
if !(shape > 0.0) {
pub fn new(shape: N, scale: N) -> Result<Gamma<N>, Error> {
if !(shape > N::zero()) {
return Err(Error::ShapeTooSmall);
}
if !(scale > 0.0) {
if !(scale > N::zero()) {
return Err(Error::ScaleTooSmall);
}

let repr = if shape == 1.0 {
One(Exp::new(1.0 / scale).map_err(|_| Error::ScaleTooLarge)?)
} else if shape < 1.0 {
let repr = if shape == N::one() {
One(Exp::new(N::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
} else if shape < N::one() {
Small(GammaSmallShape::new_raw(shape, scale))
} else {
Large(GammaLargeShape::new_raw(shape, scale))
Expand All @@ -123,57 +126,69 @@ impl Gamma {
}
}

impl GammaSmallShape {
fn new_raw(shape: f64, scale: f64) -> GammaSmallShape {
impl<N: Float> GammaSmallShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn new_raw(shape: N, scale: N) -> GammaSmallShape<N> {
GammaSmallShape {
inv_shape: 1. / shape,
large_shape: GammaLargeShape::new_raw(shape + 1.0, scale)
inv_shape: N::one() / shape,
large_shape: GammaLargeShape::new_raw(shape + N::one(), scale)
}
}
}

impl GammaLargeShape {
fn new_raw(shape: f64, scale: f64) -> GammaLargeShape {
let d = shape - 1. / 3.;
impl<N: Float> GammaLargeShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn new_raw(shape: N, scale: N) -> GammaLargeShape<N> {
let d = shape - N::from(1. / 3.).unwrap();
GammaLargeShape {
scale,
c: 1. / (9. * d).sqrt(),
c: N::one() / (N::from(9.).unwrap() * d).sqrt(),
d
}
}
}

impl Distribution<f64> for Gamma {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl<N: Float> Distribution<N> for Gamma<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
match self.repr {
Small(ref g) => g.sample(rng),
One(ref g) => g.sample(rng),
Large(ref g) => g.sample(rng),
}
}
}
impl Distribution<f64> for GammaSmallShape {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let u: f64 = rng.sample(Open01);
impl<N: Float> Distribution<N> for GammaSmallShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
let u: N = rng.sample(Open01);

self.large_shape.sample(rng) * u.powf(self.inv_shape)
}
}
impl Distribution<f64> for GammaLargeShape {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl<N: Float> Distribution<N> for GammaLargeShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
// Marsaglia & Tsang method, 2000
loop {
let x = rng.sample(StandardNormal);
let v_cbrt = 1.0 + self.c * x;
if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0
let x: N = rng.sample(StandardNormal);
let v_cbrt = N::one() + self.c * x;
if v_cbrt <= N::zero() { // a^3 <= 0 iff a <= 0
continue
}

let v = v_cbrt * v_cbrt * v_cbrt;
let u: f64 = rng.sample(Open01);
let u: N = rng.sample(Open01);

let x_sqr = x * x;
if u < 1.0 - 0.0331 * x_sqr * x_sqr ||
u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln()) {
if u < N::one() - N::from(0.0331).unwrap() * x_sqr * x_sqr ||
u.ln() < N::from(0.5).unwrap() * x_sqr + self.d * (N::one() - v + v.ln())
{
return self.d * v * self.scale
}
}
Expand Down Expand Up @@ -215,7 +230,7 @@ enum ChiSquaredRepr {
// e.g. when alpha = 1/2 as it would be for this case, so special-
// casing and using the definition of N(0,1)^2 is faster.
DoFExactlyOne,
DoFAnythingElse(Gamma),
DoFAnythingElse(Gamma<f64>),
}

impl ChiSquared {
Expand All @@ -238,7 +253,7 @@ impl Distribution<f64> for ChiSquared {
match self.repr {
DoFExactlyOne => {
// k == 1 => N(0,1)^2
let norm = rng.sample(StandardNormal);
let norm: f64 = rng.sample(StandardNormal);
norm * norm
}
DoFAnythingElse(ref g) => g.sample(rng)
Expand Down Expand Up @@ -332,7 +347,7 @@ impl StudentT {
}
impl Distribution<f64> for StudentT {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let norm = rng.sample(StandardNormal);
let norm: f64 = rng.sample(StandardNormal);
norm * (self.dof / self.chi.sample(rng)).sqrt()
}
}
Expand All @@ -350,8 +365,8 @@ impl Distribution<f64> for StudentT {
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Beta {
gamma_a: Gamma,
gamma_b: Gamma,
gamma_a: Gamma<f64>,
gamma_b: Gamma<f64>,
}

/// Error type returned from `Beta::new`.
Expand Down