Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA BFloat16 signal windows #45155

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
102 changes: 50 additions & 52 deletions aten/src/ATen/native/cuda/RangeFactories.cu
Expand Up @@ -207,62 +207,60 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {

Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "arange_cuda", [&]() {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "arange_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto xstart = start.to<accscalar_t>();
auto xend = end.to<accscalar_t>();
auto xstep = step.to<accscalar_t>();

// we use double precision for (start - end) / step
// to compute size_d for consistency across devices.
// The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t,
// but double on cpu for the same,
// and the effective output size starts differing on CPU vs GPU because of precision issues, which
// we dont want.
// the corner-case we do want to take into account is int64_t, which has higher precision than double
double size_d;
if (std::is_same<scalar_t, int64_t>::value) {
size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
/ step.to<accscalar_t>());
} else {
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
/ step.to<double>());
}

TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
std::isfinite(static_cast<double>(xend)),
"unsupported range: ", xstart, " -> ", xend);
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
"upper bound and larger bound inconsistent with step sign");

TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
"invalid size, possible overflow?");
int64_t size = static_cast<int64_t>(size_d);
int64_t numel = result.numel();

if (numel != size) {
if(numel > 0){
TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(),
" is ", numel, " which does not match the computed number of elements ", size,
". Note that this may occur as a result of rounding error. "
"The out tensor will be resized to a tensor of shape (", size, ",).");
}
result.resize_({size});
}
bool is_contiguous = result.is_contiguous();
Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;
using accscalar_t = at::acc_type<scalar_t, true>;
auto xstart = start.to<accscalar_t>();
auto xend = end.to<accscalar_t>();
auto xstep = step.to<accscalar_t>();

gpu_kernel_with_index(r, [xstart, xstep]GPU_LAMBDA(int64_t ind) -> scalar_t {
accscalar_t inc = xstep * static_cast<accscalar_t>(ind);
accscalar_t val = xstart + inc;
return static_cast<scalar_t>(val);
});
// we use double precision for (start - end) / step
// to compute size_d for consistency across devices.
// The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t,
// but double on cpu for the same,
// and the effective output size starts differing on CPU vs GPU because of precision issues, which
// we dont want.
// the corner-case we do want to take into account is int64_t, which has higher precision than double
double size_d;
if (std::is_same<scalar_t, int64_t>::value) {
size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
/ step.to<accscalar_t>());
} else {
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
/ step.to<double>());
}

if(!is_contiguous) {
result.copy_(r);
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
std::isfinite(static_cast<double>(xend)),
"unsupported range: ", xstart, " -> ", xend);
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
"upper bound and larger bound inconsistent with step sign");

TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
"invalid size, possible overflow?");
int64_t size = static_cast<int64_t>(size_d);
int64_t numel = result.numel();

if (numel != size) {
if(numel > 0){
TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(),
" is ", numel, " which does not match the computed number of elements ", size,
". Note that this may occur as a result of rounding error. "
"The out tensor will be resized to a tensor of shape (", size, ",).");
}
result.resize_({size});
}
bool is_contiguous = result.is_contiguous();
Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;

gpu_kernel_with_index(r, [xstart, xstep]GPU_LAMBDA(int64_t ind) -> scalar_t {
accscalar_t inc = xstep * static_cast<accscalar_t>(ind);
accscalar_t val = xstart + inc;
return static_cast<scalar_t>(val);
});

if(!is_contiguous) {
result.copy_(r);
}
});

AT_CUDA_CHECK(cudaGetLastError());
Expand Down
15 changes: 9 additions & 6 deletions aten/src/ATen/native/cuda/UnaryOpsKernel.cu
Expand Up @@ -223,13 +223,16 @@ void nan_to_num_kernel_cuda(
});
}

void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta){
void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta_){
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "kaiser_window_cuda", [&] {
const scalar_t alpha = static_cast<scalar_t>((window_length - 1) / 2.0);
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t {
return calc_i0(static_cast<scalar_t>(beta) * ::sqrt(1 - ::pow((a - alpha) / alpha, static_cast<scalar_t>(2.0)))) / calc_i0(static_cast<scalar_t>(beta));
});
using T_ACC = acc_type<scalar_t, true>;
const T_ACC inv_alpha = static_cast<T_ACC>(2.0 / (window_length - 1));
const T_ACC beta = static_cast<T_ACC>(beta_);
const T_ACC inv_i0_beta = 1.0 / calc_i0(beta);
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t {
T_ACC x = static_cast<T_ACC>(a) * inv_alpha - 1;
T_ACC y = std::max<T_ACC>(0, 1 - x * x);
return calc_i0(beta * ::sqrt(y)) * inv_i0_beta;
});
});
}
Expand Down
15 changes: 10 additions & 5 deletions test/test_torch.py
Expand Up @@ -8092,21 +8092,26 @@ def test_complex_rot90(self, device, dtype):
self.compare_with_numpy(torch_fn, np_fn, data)

@onlyOnCPUAndCUDA
@precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3})
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
def test_signal_window_functions(self, device):
@dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long)
@dtypesIfCPU(torch.float, torch.double, torch.long)
def test_signal_window_functions(self, device, dtype):

def test(name, kwargs):
torch_method = getattr(torch, name + '_window')
if not dtype.is_floating_point:
with self.assertRaisesRegex(RuntimeError, r'floating point'):
torch_method(3, dtype=dtype)
return
for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]:
for periodic in [True, False]:
res = torch_method(size, periodic=periodic, **kwargs, device=device)
# NB: scipy always returns a float32 result
res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype)
# NB: scipy always returns a float64 result
ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic))
self.assertEqual(res, ref, exact_dtype=False)
with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'):
torch_method(3, layout=torch.sparse_coo)
with self.assertRaisesRegex(RuntimeError, r'floating point'):
torch_method(3, dtype=torch.long)
self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
self.assertFalse(torch_method(3).requires_grad)

Expand Down