
CUDA BFloat activations 1 (#44834)
Summary: Pull Request resolved: #44834

Reviewed By: mruberry

Differential Revision: D23752660

Pulled By: ngimel

fbshipit-source-id: 209a937e8a9afe12b7dd86ecfa493c9417fd22fb
zasdfgbnm authored and facebook-github-bot committed Sep 18, 2020
1 parent 76a109c commit 0638940
Showing 2 changed files with 35 additions and 56 deletions.
90 changes: 35 additions & 55 deletions aten/src/ATen/native/cuda/Activation.cu
@@ -246,33 +246,27 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
 // -----------------------------------
 void hardshrink_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardshrink_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardshrink_cuda", [&] {
-      auto lambd = value.to<scalar_t>();
-      gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return (a >= -lambd && a <= lambd) ? scalar_t(0) : a;
-      });
-    });
+    auto lambd = value.to<scalar_t>();
+    gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return (a >= -lambd && a <= lambd) ? scalar_t(0) : a;
+    });
   });
 }

 void softshrink_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softshrink_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softshrink_cuda", [&] {
-      auto lambd = value.to<scalar_t>();
-      gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0));
-      });
-    });
+    auto lambd = value.to<scalar_t>();
+    gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0));
+    });
   });
 }

 void shrink_backward_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "shrink_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "shrink_backward_cuda", [&] {
-      auto lambd = value.to<scalar_t>();
-      gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t {
-        return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) : grad_val;
-      });
-    });
+    auto lambd = value.to<scalar_t>();
+    gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t {
+      return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) : grad_val;
+    });
   });
 }
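
Dropping the AT_SKIP_BFLOAT16_IF_NOT_ROCM wrapper means the BFloat16 instantiation generated by AT_DISPATCH_FLOATING_TYPES_AND2 now runs on CUDA instead of erroring out on non-ROCm builds. A minimal Python sketch of what this enables, assuming a CUDA device and a PyTorch build containing this change (the tolerance is illustrative):

import torch
import torch.nn.functional as F

# Evaluate hardshrink on the same values in bfloat16 and float32.
x = torch.randn(1024, device="cuda").to(torch.bfloat16)
y_bf16 = F.hardshrink(x, lambd=0.5)          # now dispatches on CUDA
y_ref = F.hardshrink(x.float(), lambd=0.5)   # float32 reference

# hardshrink only zeroes or passes values through, so the results
# should agree to well within bfloat16 rounding error.
assert torch.allclose(y_bf16.float(), y_ref, atol=1e-2, rtol=0)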
@@ -289,25 +283,21 @@ void hardtanh_backward_kernel(TensorIterator& iter, Scalar min, Scalar max) {

 void softplus_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softplus_cuda", [&] {
-      auto beta = beta_.to<scalar_t>();
-      auto threshold = threshold_.to<scalar_t>();
-      gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return (a * beta) > threshold ? a : static_cast<scalar_t>(::log1p(std::exp(a * beta))) / beta;
-      });
-    });
+    auto beta = beta_.to<scalar_t>();
+    auto threshold = threshold_.to<scalar_t>();
+    gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return (a * beta) > threshold ? a : static_cast<scalar_t>(::log1p(std::exp(a * beta))) / beta;
+    });
   });
 }

 void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softplus_backward_cuda", [&] {
-      auto beta = beta_.to<scalar_t>();
-      auto threshold = threshold_.to<scalar_t>();
-      gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        scalar_t z = std::exp(b * beta);
-        return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
-      });
-    });
+    auto beta = beta_.to<scalar_t>();
+    auto threshold = threshold_.to<scalar_t>();
+    gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      scalar_t z = std::exp(b * beta);
+      return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
+    });
   });
 }
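
The softplus forward falls back to the identity once a * beta exceeds the threshold and otherwise computes log1p(exp(a * beta)) / beta. A hedged sketch of the same computation at the Python level, again assuming a CUDA device and a build with this change:

import torch
import torch.nn.functional as F

beta, threshold = 1.0, 20.0
x = torch.randn(1024, device="cuda").to(torch.bfloat16)

# The kernel's formula, spelled out with tensor ops.
manual = torch.where(x * beta > threshold,
                     x,
                     torch.log1p(torch.exp(x * beta)) / beta)

out = F.softplus(x, beta=beta, threshold=threshold)
assert torch.allclose(out, manual, atol=1e-2, rtol=0)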
@@ -321,34 +311,28 @@ void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t va

 static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar value) {
   AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "threshold_cuda", [&] {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "threshold_cuda", [&] {
-      threshold_kernel_impl<scalar_t>(iter, threshold.to<scalar_t>(), value.to<scalar_t>());
-    });
+    threshold_kernel_impl<scalar_t>(iter, threshold.to<scalar_t>(), value.to<scalar_t>());
   });
 }

 void elu_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_scale) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "elu_cuda", [&] {
-      auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
-      auto poscoef = scale.to<scalar_t>();
-      auto negiptcoef = input_scale.to<scalar_t>();
-      gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return a > scalar_t(0) ? a * poscoef : (static_cast<scalar_t>(std::exp(a * negiptcoef)) - scalar_t(1.)) * negcoef;
-      });
-    });
+    auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
+    auto poscoef = scale.to<scalar_t>();
+    auto negiptcoef = input_scale.to<scalar_t>();
+    gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return a > scalar_t(0) ? a * poscoef : (static_cast<scalar_t>(std::exp(a * negiptcoef)) - scalar_t(1.)) * negcoef;
+    });
   });
 }

 void elu_backward_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_scale) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "elu_backward_cuda", [&] {
-      auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
-      auto poscoef = scale.to<scalar_t>();
-      auto negiptcoef = input_scale.to<scalar_t>();
-      gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef;
-      });
-    });
+    auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
+    auto poscoef = scale.to<scalar_t>();
+    auto negiptcoef = input_scale.to<scalar_t>();
+    gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef;
+    });
   });
 }
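
In elu_backward_kernel, b is the saved forward output: the derivative of negcoef * (exp(x * negiptcoef) - 1) is negiptcoef * (output + negcoef), which is exactly the negative branch above. A sketch checking the bfloat16 backward end to end, with illustrative defaults (alpha = scale = input_scale = 1, so all three coefficients are 1):

import torch
import torch.nn.functional as F

x = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = F.elu(x, alpha=1.0)
y.backward(torch.ones_like(y))

# Negative side: grad * (output + 1); positive side: grad itself.
expected = torch.where(y.detach() > 0, torch.ones_like(x), y.detach() + 1.0)
assert torch.allclose(x.grad.float(), expected.float(), atol=1e-2, rtol=0)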
@@ -387,22 +371,18 @@ void GeluBackwardCUDAKernelImpl(TensorIterator& it) {

 void leaky_relu_kernel(TensorIterator& iter, Scalar negval_) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "leaky_relu_cuda", [&] {
-      auto negval = negval_.to<scalar_t>();
-      gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return a > scalar_t(0) ? a : a * negval;
-      });
-    });
+    auto negval = negval_.to<scalar_t>();
+    gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return a > scalar_t(0) ? a : a * negval;
+    });
   });
 }

 void leaky_relu_backward_kernel(TensorIterator& iter, Scalar negval_) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "leaky_relu_backward_cuda", [&] {
-      auto negval = negval_.to<scalar_t>();
-      gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        return a > scalar_t(0) ? b : b * negval;
-      });
-    });
+    auto negval = negval_.to<scalar_t>();
+    gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return a > scalar_t(0) ? b : b * negval;
+    });
   });
 }
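
The leaky_relu backward lambda passes the incoming gradient b through where the input a is positive and scales it by negval elsewhere. A short end-to-end sketch in bfloat16, under the same assumptions as above:

import torch
import torch.nn.functional as F

negval = 0.01
x = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
F.leaky_relu(x, negative_slope=negval).sum().backward()

# Gradient is 1 where x > 0 and negval elsewhere.
expected = torch.where(x.detach() > 0,
                       torch.ones_like(x),
                       torch.full_like(x, negval))
assert torch.allclose(x.grad, expected, atol=1e-3, rtol=0)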
1 change: 0 additions & 1 deletion test/test_nn.py
@@ -11832,7 +11832,6 @@ def _test_bfloat16_ops(self, op, device, inp_dims=(), prec=1e-2):
         self.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=0, exact_dtype=False)

     @onlyCUDA
-    @skipCUDAIfNotRocm
     def test_activations_bfloat16(self, device):
         self._test_bfloat16_ops(torch.nn.ReLU(), device, inp_dims=(5), prec=1e-2)
         self._test_bfloat16_ops(torch.nn.Threshold(0.1, 20), device, inp_dims=(5), prec=1e-2)
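
With the @skipCUDAIfNotRocm decorator removed, test_activations_bfloat16 now runs on every CUDA device. The body of _test_bfloat16_ops is largely elided above; only its signature and final gradient assertion are visible, so the following is a self-contained approximation of that contract (the helper name and details are assumptions for illustration, not the actual test code):

import torch

def check_bfloat16_op(op, device, inp_dims=(), prec=1e-2):
    # Hypothetical stand-in for _test_bfloat16_ops: run the module on the
    # same values in bfloat16 and float32, then compare input gradients.
    input1 = torch.randn(inp_dims, dtype=torch.bfloat16, device=device, requires_grad=True)
    input2 = input1.detach().float().requires_grad_()
    op(input1).sum().backward()
    op(input2).sum().backward()
    # Mirrors the visible assertion: atol=prec, rtol=0, dtypes ignored.
    torch.testing.assert_close(input1.grad.float(), input2.grad, atol=prec, rtol=0)

check_bfloat16_op(torch.nn.ReLU(), "cuda", inp_dims=(5,))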
