Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix uniform returning end point for BFloat16 and Half #96962

Closed
wants to merge 5 commits into from
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 10 additions & 6 deletions aten/src/ATen/native/cuda/DistributionTemplates.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandBase.h>
#include <ATen/OpMathType.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/util/Half.h>
Expand Down Expand Up @@ -458,19 +459,22 @@ void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen)
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
auto from = static_cast<scalar_t>(from_);
auto to = static_cast<scalar_t>(to_);
using accscalar_t = at::acc_type<scalar_t, true>;
auto range = static_cast<accscalar_t>(to-from);
using opmath_t = at::opmath_type<scalar_t>;
auto range = static_cast<opmath_t>(to-from);
// define lambda to reverse bounds, multiply 'range' and add 'from_'
auto uniform_func = [range, from] __device__ (accscalar_t rand) {
auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
// Compute output value before reversing the bounds
// BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
auto value = static_cast<scalar_t>(rand * range + from);
// reverse the bounds of curand4 from (0, 1] to [0, 1)
// Note that this method is from legacy THCTensorRandom and is likely to give
// you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
// by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
// BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
auto reverse_bound_rand = rand == static_cast<accscalar_t>(1.0) ? static_cast<accscalar_t>(0.0) : rand;
return static_cast<scalar_t>(reverse_bound_rand * range + from);
auto reverse_bound_value = value == to ? from : value;
return reverse_bound_value;
};
uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, uniform_func);
uniform_and_transform<scalar_t, opmath_t, curand4_engine_calls>(iter, gen, uniform_func);
});
}

Expand Down