forked from celeritas-project/celeritas
-
Notifications
You must be signed in to change notification settings - Fork 0
/
NormalDistribution.hh
103 lines (91 loc) · 3.2 KB
/
NormalDistribution.hh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
//----------------------------------*-C++-*----------------------------------//
// Copyright 2021-2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/random/distribution/NormalDistribution.hh
//---------------------------------------------------------------------------//
#pragma once
#include <cmath>
#include <type_traits>
#include "corecel/Assert.hh"
#include "corecel/Macros.hh"
#include "corecel/Types.hh"
#include "corecel/math/Algorithms.hh"
#include "celeritas/Constants.hh"
#include "GenerateCanonical.hh"
namespace celeritas
{
//---------------------------------------------------------------------------//
/*!
 * Sample from a normal distribution.
 *
 * This uses the Box-Muller transform to generate pairs of independent,
 * normally distributed random numbers, returning them one at a time. Two
 * independent random numbers \f$ \xi_1, \xi_2 \f$ uniformly distributed on the
 * unit interval are mapped to two independent, standard, normally distributed
 * samples using the relations:
 * \f[
   x_1 = \sqrt{-2 \ln \xi_1} \cos(2 \pi \xi_2)
 * \f]
 * \f[
   x_2 = \sqrt{-2 \ln \xi_1} \sin(2 \pi \xi_2)
 * \f]
 * The logarithm is only finite for \f$ \xi_1 > 0 \f$. Each standard sample
 * \f$ x \f$ is mapped to \f$ \mu + \sigma x \f$ using the mean \f$ \mu \f$
 * and standard deviation \f$ \sigma \f$ given at construction.
 *
 * The unused sample of each pair is stored inside the distribution object,
 * so instances are stateful: successive calls alternate between generating a
 * new pair and returning the stored sample.
 */
template<class RealType = ::celeritas::real_type>
class NormalDistribution
{
    static_assert(std::is_floating_point_v<RealType>,
                  "NormalDistribution requires a floating-point type");

  public:
    //!@{
    //! \name Type aliases
    using real_type = RealType;
    using result_type = real_type;
    //!@}

  public:
    // Construct with mean and standard deviation
    explicit inline CELER_FUNCTION
    NormalDistribution(real_type mean = 0, real_type stddev = 1);

    // Sample a random number according to the distribution
    template<class Generator>
    inline CELER_FUNCTION result_type operator()(Generator& rng);

  private:
    //! Location (mean) of the distribution
    real_type const mean_;
    //! Scale (standard deviation) of the distribution
    real_type const stddev_;
    //! Unreturned standard-normal sample from the most recent transform
    real_type spare_{};
    //! Whether \c spare_ holds a sample that has not been returned yet
    bool has_spare_{false};
};
//---------------------------------------------------------------------------//
// INLINE DEFINITIONS
//---------------------------------------------------------------------------//
/*!
* Construct with mean and standard deviation.
*/
template<class RealType>
CELER_FUNCTION
NormalDistribution<RealType>::NormalDistribution(real_type mean,
real_type stddev)
: mean_(mean), stddev_(stddev)
{
CELER_EXPECT(stddev > 0);
}
//---------------------------------------------------------------------------//
/*!
 * Sample a random number according to the distribution.
 *
 * When a sample is cached from the previous call, it is scaled, shifted and
 * returned without consuming any random numbers. Otherwise a new Box-Muller
 * pair is generated from fresh uniform deviates:
 * - one sample is returned;
 * - the other is cached for the next call.
 */
template<class RealType>
template<class Generator>
CELER_FUNCTION auto NormalDistribution<RealType>::operator()(Generator& rng)
    -> result_type
{
    if (has_spare_)
    {
        // Return the second sample of the previous transform
        has_spare_ = false;
        return std::fma(spare_, stddev_, mean_);
    }

    constexpr auto twopi = static_cast<RealType>(2 * m_pi);
    real_type theta = twopi * generate_canonical<RealType>(rng);

    // The radial deviate feeds a logarithm, so it must be nonzero:
    // log(0) = -inf would make r infinite and both samples +-inf or NaN.
    // If the generator can return exactly zero (e.g. a [0, 1) canonical
    // generator), redraw. Redrawing is statistically exact (conditions on
    // xi > 0) and leaves every nonzero draw, and thus the sample stream,
    // unchanged.
    real_type xi;
    do
    {
        xi = generate_canonical<RealType>(rng);
    } while (xi == 0);
    real_type r = std::sqrt(-2 * std::log(xi));

    spare_ = r * std::cos(theta);
    has_spare_ = true;
    return std::fma(r * std::sin(theta), stddev_, mean_);
}
//---------------------------------------------------------------------------//
} // namespace celeritas