From f1cb78bbee488b0e326b1e7117400e19ea3816e0 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Thu, 3 Dec 2020 14:37:53 -0800
Subject: [PATCH] CUDA BF16 backwards

---
 .../cuda/BinaryMiscBackwardOpsKernels.cu | 60 +++++++++----------
 1 file changed, 27 insertions(+), 33 deletions(-)

diff --git a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu
index ed7e2190f75e..a385aa721522 100644
--- a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu
@@ -16,10 +16,8 @@ namespace native {
 
 void sigmoid_backward_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "sigmoid_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "sigmoid_backward_cuda", [&] {
-      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        return a * (scalar_t(1.) - b) * b;
-      });
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return a * (scalar_t(1.) - b) * b;
     });
   });
 }
@@ -31,31 +29,29 @@ void logit_backward_kernel_cuda(TensorIterator& iter, Scalar eps_scalar) {
       iter.dtype(),
       "logit_cuda",
       [&]() {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "logit_cuda", [&] {
-          using T_ACC = acc_type<scalar_t, true>;
-          const T_ACC eps = eps_scalar.to<T_ACC>();
-          if (eps < T_ACC(0)) {
-            gpu_kernel(
-                iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
-                  const T_ACC dy_acc = static_cast<T_ACC>(dy);
-                  const T_ACC x_acc = static_cast<T_ACC>(x);
-                  return (x_acc < T_ACC(0) || x_acc > T_ACC(1))
-                      ? std::numeric_limits<T_ACC>::quiet_NaN()
-                      : dy_acc / (x_acc * (T_ACC(1) - x_acc));
-                });
-          } else {
-            const T_ACC lo = eps;
-            const T_ACC hi = T_ACC(1) - eps;
-            gpu_kernel(
-                iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
-                  const T_ACC dy_acc = static_cast<T_ACC>(dy);
-                  const T_ACC x_acc = static_cast<T_ACC>(x);
-                  return (x_acc < lo || x_acc > hi)
-                      ? T_ACC(0)
-                      : dy_acc / (x_acc * (T_ACC(1) - x_acc));
-                });
-          }
-        });
+        using T_ACC = acc_type<scalar_t, true>;
+        const T_ACC eps = eps_scalar.to<T_ACC>();
+        if (eps < T_ACC(0)) {
+          gpu_kernel(
+              iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
+                const T_ACC dy_acc = static_cast<T_ACC>(dy);
+                const T_ACC x_acc = static_cast<T_ACC>(x);
+                return (x_acc < T_ACC(0) || x_acc > T_ACC(1))
+                    ? std::numeric_limits<T_ACC>::quiet_NaN()
+                    : dy_acc / (x_acc * (T_ACC(1) - x_acc));
+              });
+        } else {
+          const T_ACC lo = eps;
+          const T_ACC hi = T_ACC(1) - eps;
+          gpu_kernel(
+              iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
+                const T_ACC dy_acc = static_cast<T_ACC>(dy);
+                const T_ACC x_acc = static_cast<T_ACC>(x);
+                return (x_acc < lo || x_acc > hi)
+                    ? T_ACC(0)
+                    : dy_acc / (x_acc * (T_ACC(1) - x_acc));
+              });
+        }
       });
 }
 
@@ -68,10 +64,8 @@ void tanh_backward_kernel_cuda(TensorIterator& iter) {
     });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "tanh_backward_cuda", [&] {
-        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-          return a * (scalar_t{1.} - b * b);
-        });
+      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+        return a * (scalar_t{1.} - b * b);
       });
     });
   }
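
Illustrative note (not part of the patch): the kernels above compute the elementwise gradients dL/dx = dy * (1 - y) * y for sigmoid and dL/dx = dy * (1 - y * y) for tanh, with the BF16 skip macro removed so bfloat16 now goes through the same lambdas. The following is a minimal standalone CUDA sketch of that arithmetic, assuming CUDA 11+ and cuda_bf16.h; the kernel name, buffer names, and launch configuration are illustrative only and do not reflect PyTorch's gpu_kernel/TensorIterator machinery.

// Standalone sketch: sigmoid and tanh backward on bfloat16 buffers,
// with the math done in float, analogous to what the lambdas above do.
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cstdio>

__global__ void sigmoid_tanh_backward(const __nv_bfloat16* grad_out,
                                       const __nv_bfloat16* out,
                                       __nv_bfloat16* grad_sigmoid,
                                       __nv_bfloat16* grad_tanh,
                                       int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Promote bfloat16 inputs to float before computing.
    float dy = __bfloat162float(grad_out[i]);
    float y  = __bfloat162float(out[i]);
    // sigmoid backward: dL/dx = dy * (1 - y) * y
    grad_sigmoid[i] = __float2bfloat16(dy * (1.f - y) * y);
    // tanh backward: dL/dx = dy * (1 - y * y)
    grad_tanh[i] = __float2bfloat16(dy * (1.f - y * y));
  }
}

int main() {
  const int n = 4;
  __nv_bfloat16 h_dy[n], h_y[n], h_gs[n], h_gt[n];
  for (int i = 0; i < n; ++i) {
    h_dy[i] = __float2bfloat16(1.0f);        // upstream gradient
    h_y[i]  = __float2bfloat16(0.25f * i);   // saved forward output
  }
  __nv_bfloat16 *d_dy, *d_y, *d_gs, *d_gt;
  cudaMalloc(&d_dy, n * sizeof(__nv_bfloat16));
  cudaMalloc(&d_y,  n * sizeof(__nv_bfloat16));
  cudaMalloc(&d_gs, n * sizeof(__nv_bfloat16));
  cudaMalloc(&d_gt, n * sizeof(__nv_bfloat16));
  cudaMemcpy(d_dy, h_dy, n * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
  cudaMemcpy(d_y,  h_y,  n * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
  sigmoid_tanh_backward<<<1, 128>>>(d_dy, d_y, d_gs, d_gt, n);
  cudaMemcpy(h_gs, d_gs, n * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_gt, d_gt, n * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) {
    printf("y=%.2f sigmoid_grad=%.4f tanh_grad=%.4f\n",
           __bfloat162float(h_y[i]),
           __bfloat162float(h_gs[i]),
           __bfloat162float(h_gt[i]));
  }
  cudaFree(d_dy); cudaFree(d_y); cudaFree(d_gs); cudaFree(d_gt);
  return 0;
}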