Skip to content

Commit

Permalink
Location-shift MKL Exponential Distribution (#101720)
Browse files Browse the repository at this point in the history
  • Loading branch information
min-jean-cho authored and pytorchmergebot committed May 25, 2023
1 parent d4380ed commit 3ca068b
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions aten/src/ATen/native/cpu/DistributionKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,19 @@ void exponential_kernel(TensorIteratorBase &iter, double lambda, c10::optional<G
using tmp_scalar_t = typename std::conditional_t<std::is_same<scalar_t, double>::value, double, float>;
tmp_scalar_t *sample_ptr = tmp_tensor.data_ptr<tmp_scalar_t>();

// Intel MKL vRngExponential variate originally does not exclude 0.
// However, to align with pytorch exponential variate definition which excludes 0,
// we shift the MKL vRngExponential distribution location by adding a very small constant, eps.
// If X ~ Exp(lambda), then E(X) = 1/lambda, and V(X) = 1/lambda**2.
// If Y = X + eps, where eps ~= 0, then E(Y) = (1/lambda) + eps, and V(Y) = 1/lambda**2.
// If eps is very small, the two distributions are indistinguishable, and are almost identical.
// The detail of location-shifted MKL vRngExponential is as follows.
// PDF: f(x) = lambda * exp( -lambda * (x - eps) )
// CDF: F(x) = 1 - exp( -lambda * (x - eps) )
// Mean: E[X+eps] = (1/lambda) + eps
// Variance: V[X+eps] = 1/lambda**2
auto eps = std::numeric_limits<tmp_scalar_t>::min();

auto sample = [&](int64_t begin, int64_t end) {
int64_t len = end - begin;
if (len > 0) {
Expand All @@ -149,13 +162,13 @@ void exponential_kernel(TensorIteratorBase &iter, double lambda, c10::optional<G
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
vslSkipAheadStream(stream, begin);
vdRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, stream, len,
(double *)(sample_ptr + begin), 0, 1./lambda);
(double *)(sample_ptr + begin), eps, 1./lambda);
vslDeleteStream(&stream);
} else {
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
vslSkipAheadStream(stream, begin);
vsRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, stream, len,
(float *) (sample_ptr + begin), 0, 1./lambda);
(float *) (sample_ptr + begin), eps, 1./lambda);
vslDeleteStream(&stream);
}
// vectorized copy if using buffer and contiguous
Expand Down

0 comments on commit 3ca068b

Please sign in to comment.