diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 29c51a107b29..113723506c98 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -13,6 +13,7 @@ namespace at {
 namespace native {
 
 DEFINE_DISPATCH(add_stub);
+DEFINE_DISPATCH(add_clamp_stub);
 DEFINE_DISPATCH(sub_stub);
 DEFINE_DISPATCH(mul_stub);
 DEFINE_DISPATCH(div_stub);
@@ -62,6 +63,55 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
   return native::add_out(self, self, other, alpha);
 }
 
+// Fused add + relu: result = clamp(self + alpha * other, 0, dtype_max).
+Tensor& add_relu_impl(
+    Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_mem_overlap=*/true);
+  Scalar min_val;
+  Scalar max_val;
+  if (self.dtype() == at::kInt) {
+    min_val = 0;
+    max_val = std::numeric_limits<int32_t>::max();
+  } else if (self.dtype() == at::kLong) {
+    min_val = 0;
+    max_val = std::numeric_limits<int64_t>::max();
+  } else if (self.dtype() == at::kShort) {
+    min_val = 0;
+    max_val = std::numeric_limits<int16_t>::max();
+  } else if (self.dtype() == at::kChar) {
+    min_val = 0;
+    max_val = std::numeric_limits<int8_t>::max();
+  } else if (self.dtype() == at::kFloat) {
+    min_val = 0.0;
+    max_val = std::numeric_limits<float>::max();
+  } else if (self.dtype() == at::kDouble) {
+    min_val = 0.0;
+    max_val = std::numeric_limits<double>::max();
+  } else {
+    // First argument must be the (false) condition; the remaining
+    // arguments form the failure message.
+    TORCH_INTERNAL_ASSERT(false,
+        "Unsupported datatype for add_relu:", self.dtype().name());
+  }
+
+  result = iter.output();
+  add_clamp_stub(iter.device_type(), iter, alpha, min_val, max_val);
+  return result;
+}
+
+Tensor& add_relu_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
+  return add_relu_impl(result, self, other, alpha);
+}
+
+Tensor add_relu(const Tensor& self, const Tensor& other, Scalar alpha) {
+  Tensor result;
+  return add_relu_impl(result, self, other, alpha);
+}
+
+Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
+  return add_relu_impl(self, self, other, alpha);
+}
+
 Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (isIntegralType(result.scalar_type(), /*includeBool=*/ true)) {
     TORCH_CHECK(false,
diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h
index 86d271a43378..850ba89c2c3f 100644
--- a/aten/src/ATen/native/BinaryOps.h
+++ b/aten/src/ATen/native/BinaryOps.h
@@ -26,8 +26,11 @@ inline void sub_check(const Tensor& self, const Tensor& other) {
 
 using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha);
 using binary_fn = void(*)(TensorIterator&);
+using binary_clamp_fn_alpha =
+    void(*)(TensorIterator&, Scalar alpha, Scalar min_val, Scalar max_val);
 
 DECLARE_DISPATCH(binary_fn_alpha, add_stub);
+DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
 DECLARE_DISPATCH(binary_fn_alpha, sub_stub);
 DECLARE_DISPATCH(binary_fn, mul_stub);
 DECLARE_DISPATCH(binary_fn, div_stub);
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 43af4182609e..d5a5b744ecaa 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -35,6 +35,28 @@ void add_kernel(TensorIterator& iter, Scalar alpha_scalar) {
   }
 }
 
+// Fused kernel: out = clamp(a + alpha * b, min_val, max_val).
+void add_clamp_kernel(TensorIterator& iter, Scalar alpha_scalar, Scalar min_val, Scalar max_val) {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_clamp_cpu", [&]() {
+    auto alpha = alpha_scalar.to<scalar_t>();
+    auto alpha_vec = Vec256<scalar_t>(alpha);
+    auto min_scalar = min_val.to<scalar_t>();
+    auto min_vec = Vec256<scalar_t>(min_scalar);
+    auto max_scalar = max_val.to<scalar_t>();
+    auto max_vec = Vec256<scalar_t>(max_scalar);
+    cpu_kernel_vec(iter,
+      [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t {
+        // static_cast undoes integer promotion so min/max deduce scalar_t.
+        return std::min(max_scalar, std::max(min_scalar, static_cast<scalar_t>(a + alpha * b)));
+      },
+      [=](Vec256<scalar_t> a, Vec256<scalar_t> b) __ubsan_ignore_undefined__ {
+        auto add_clamp_res = vec256::fmadd(b, alpha_vec, a);
+        add_clamp_res = vec256::clamp_min(add_clamp_res, min_vec);
+        add_clamp_res = vec256::clamp_max(add_clamp_res, max_vec);
+        return add_clamp_res;
+      });
+  });
+}
+
 void atan2_kernel(TensorIterator& iter) {
   AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "atan2_cpu", [&]() {
     cpu_kernel_vec(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
@@ -644,6 +666,7 @@ void logaddexp2_kernel(TensorIterator& iter) {
 } // anonymous namespace
 
 REGISTER_DISPATCH(add_stub, &add_kernel);
+REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel);
 REGISTER_DISPATCH(sub_stub, &sub_kernel);
 REGISTER_DISPATCH(mul_stub, &mul_kernel);
 REGISTER_DISPATCH(div_stub, &div_kernel);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d833b01aabdf..071cd2a6fdd1 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -305,6 +305,25 @@
     SparseCUDA: add_out_sparse_cuda
     MkldnnCPU: mkldnn_add_out
 
+- func: add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  use_c10_dispatcher: full
+  variants: function
+  dispatch:
+    CPU: add_relu
+  supports_named_tensor: True
+
+- func: add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: add_relu_
+  supports_named_tensor: True
+
+- func: add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: add_relu_out
+  supports_named_tensor: True
+
 # For C++ only, until we have conversion from C++ numbers to Tensor
 - func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
   use_c10_dispatcher: full
diff --git a/test/test_nn.py b/test/test_nn.py
index 08d3868f6954..fd021aa8c46e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8716,6 +8716,21 @@ def test_fuse_module_eval_numerics(self, X, running_mean, running_var):
         self.assertEqual(Y_ref, Y_hat, msg="Conv+BN fusion results are off")
 
 
+class TestAddRelu(TestCase):
+    def test_add_relu(self):
+        a = torch.rand((7, 11))
+        b = torch.rand((7, 11))
+        a = a.float()
+        b = b.float()
+        a = a * -10
+        a = a + 5
+        add_res = a + b
+        relu_res = torch.relu(add_res)
+        add_relu_res = torch.add_relu(a, b)
+
+        self.assertTrue(torch.allclose(add_relu_res, relu_res))
+
+
 def add_test(test, decorator=None):
     def add(test_name, fn):
         if hasattr(TestNN, test_name):
diff --git a/torch/_overrides.py b/torch/_overrides.py
index 6447abbcce97..ce95de055d6e 100644
--- a/torch/_overrides.py
+++ b/torch/_overrides.py
@@ -171,6 +171,7 @@ def get_testing_overrides():
         torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
         torch.acos: lambda input, out=None: -1,
         torch.acosh: lambda input, out=None: -1,
+        torch.add_relu: lambda input, other, out=None: -1,
         torch.add: lambda input, other, out=None: -1,
         torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
         torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,