Filter 0's returned by exponential distribution (#53480)

Summary:
Fixes #48841 for the half datatype (it was fixed for the other datatypes earlier).
The reason #48841 happened for half was that `exponential_` for half tensors was producing 0s.
The exponential distribution implementation on CUDA is here: https://github.com/pytorch/pytorch/blob/e08aae261397b8da3e71024bbeddfe0487185d1d/aten/src/ATen/native/cuda/DistributionTemplates.h#L535-L545
with `transformation::exponential` defined here:
https://github.com/pytorch/pytorch/blob/e08aae261397b8da3e71024bbeddfe0487185d1d/aten/src/ATen/core/TransformationHelper.h#L113-L123
It takes a uniformly distributed random number and computes its `log`; if necessary, the result is then converted to the low-precision datatype (half). To avoid 0's, 1's are replaced with std::nextafter(1, 0) before `log` is applied. This seems fine, because log(1 - eps) is still representable in half precision (`torch.tensor([1.], device="cuda").nextafter(torch.tensor([0.], device="cuda")).log().half()` produces -5.96e-8), so casting to `scalar_t` should work. However, since the fast log approximation (`__logf`) is used, the magnitude of the computed log is ~3e-9 instead of the more accurate 5.96e-8, and it underflows to 0 when cast to half. Using `::log` instead of the fast approximation fixes it, but comes with a ~20% performance penalty on the exponential kernel for the fp32 datatype, and probably more for half.
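For illustration only (not part of this commit): a small NumPy sketch of the precision argument, with the ~3e-9 figure taken from the description above.

```python
import math
import numpy as np

# Largest float32 strictly below 1, i.e. nextafter(1, 0) = 1 - 2**-24.
almost_one = np.nextafter(np.float32(1.0), np.float32(0.0))

accurate_log = math.log(almost_one)  # about -5.96e-8
fast_log = -3e-9                     # rough value of the __logf result quoted above

# The smallest positive half-precision subnormal is 2**-24 ~= 5.96e-8, so the
# accurate value survives the cast to half while the fast-log value flushes to 0.
print(np.float16(accurate_log))      # -5.96e-08
print(np.float16(fast_log))          # -0.0, i.e. exponential_ would return 0
```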

Edit: the alternative approach used now is to filter out all small values returned by the transformation. The result is equivalent to the previous squashing of 1's to 1 - eps combined with computing the correct log of 1 - eps (which is -eps, exactly equal even for doubles). This doesn't incur a noticeable performance hit.
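A Python restatement of that clamp (illustrative only; the actual change is the C++/CUDA diff below, and `eps` here stands for the float32 machine epsilon):

```python
import numpy as np

def exponential_from_uniform(u, lam, eps):
    # curand_uniform has (0, 1] bounds, so u == 1 is possible and log(1) == 0.
    # Clamp the log at -eps/2 so the result is never 0 and does not underflow
    # when later cast to half.
    log_u = -eps / 2 if u >= 1.0 - eps / 2 else np.log(u)
    return (-1.0 / lam) * log_u

print(exponential_from_uniform(1.0, 1.0, np.finfo(np.float32).eps))  # eps/2, not 0
```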

Pull Request resolved: #53480

Reviewed By: mruberry

Differential Revision: D26924622

Pulled By: ngimel

fbshipit-source-id: dc1329e4773bf91f26af23c8afa0ae845cfb0937
ngimel authored and facebook-github-bot committed Mar 10, 2021
1 parent c5cd993 commit 6aa5148
Showing 3 changed files with 19 additions and 7 deletions.
aten/src/ATen/core/TransformationHelper.h (8 additions, 1 deletion)

```diff
@@ -116,7 +116,14 @@ C10_HOST_DEVICE __ubsan_ignore_float_divide_by_zero__ inline T exponential(T val
   // TODO: must be investigated and unified!!!
   // https://github.com/pytorch/pytorch/issues/38662
   #if defined(__CUDACC__) || defined(__HIPCC__)
-  return static_cast<T>(-1.0) / lambda * at::log(val);
+  // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
+  // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
+  // we need log to be not 0, and not underflow when converted to half
+  // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
+  auto log = val >= static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2
+      ? -std::numeric_limits<T>::epsilon() / 2
+      : at::log(val);
+  return static_cast<T>(-1.0) / lambda * log;
   #else
   return static_cast<T>(-1.0) / lambda * at::log(static_cast<T>(1.0) - val);
   #endif
```
aten/src/ATen/native/cuda/DistributionTemplates.h (0 additions, 6 deletions)

```diff
@@ -533,12 +533,6 @@ void exponential_kernel(TensorIterator& iter, double lambda_, RNG gen) {
     auto lambda = static_cast<accscalar_t>(lambda_);
     // define lambda for exponential transformation
     auto exponential_func = [lambda] __device__ (accscalar_t rand) {
-      // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
-      // Hence, squash the 1 to just below 1.
-      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
-      if(rand == static_cast<accscalar_t>(1.0)) {
-        rand = ::nextafter(static_cast<accscalar_t>(1.0), static_cast<accscalar_t>(0.0));
-      }
       return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
     };
     uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, exponential_func);
```
test/test_torch.py (11 additions, 0 deletions)

```diff
@@ -3724,6 +3724,17 @@ def test_exponential(self, device, dtype):
         with self.assertRaises(RuntimeError):
             torch.empty((1,), device=device, dtype=dtype).exponential_(-0.5)
 
+    @onlyCUDA
+    @dtypesIfCUDA(torch.half, torch.float)
+    def test_exponential_no_zero(self, device, dtype):
+        # naively, 0 in exponential can be generated with probability 2^-24
+        # so we need more samples to check if it's not generated
+        # instead of doing one
+        # don't test CPU, that would be a long test
+        x = torch.empty(50000000, device=device, dtype=dtype).exponential_()
+        self.assertTrue(x.min() > 0)
+
+
     @skipIfNoSciPy
     @dtypes(*torch.testing.get_all_fp_dtypes())
     def test_uniform_kstest(self, device, dtype):
```
