aten/src/ATen/native/BinaryOps.cpp (+48 -0)
@@ -13,6 +13,7 @@ namespace at {
namespace native {

DEFINE_DISPATCH(add_stub);
DEFINE_DISPATCH(add_clamp_stub);
DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(div_stub);
@@ -62,6 +63,53 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
  return native::add_out(self, self, other, alpha);
}

Tensor& add_relu_impl(
    Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
  auto iter = TensorIterator::binary_op(result, self, other,
    /*check_mem_overlap=*/true);
  Scalar min_val;
  Scalar max_val;
  if (self.dtype() == at::kInt) {
    min_val = 0;
    max_val = std::numeric_limits<int32_t>::max();
  } else if (self.dtype() == at::kLong) {
    min_val = 0;
    max_val = std::numeric_limits<int64_t>::max();
  } else if (self.dtype() == at::kShort) {
    min_val = 0;
    max_val = std::numeric_limits<int16_t>::max();
  } else if (self.dtype() == at::kChar) {
    min_val = 0;
    max_val = std::numeric_limits<int8_t>::max();
  } else if (self.dtype() == at::kFloat) {
    min_val = 0.0;
    max_val = std::numeric_limits<float>::max();
  } else if (self.dtype() == at::kDouble) {
    min_val = 0.0;
    max_val = std::numeric_limits<double>::max();
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Unsupported datatype for add_relu:", self.dtype().name());
  }

  result = iter.output();
  add_clamp_stub(iter.device_type(), iter, alpha, min_val, max_val);
  return result;
}

Tensor& add_relu_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
  return add_relu_impl(result, self, other, alpha);
}

Tensor add_relu(const Tensor& self, const Tensor& other, Scalar alpha) {
  Tensor result;
  return add_relu_impl(result, self, other, alpha);
}

Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
  return add_relu_impl(self, self, other, alpha);
}

Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
  if (isIntegralType(result.scalar_type(), /*includeBool=*/ true)) {
    TORCH_CHECK(false,
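For reference, the semantics implemented by add_relu_impl reduce to relu(a + alpha * b): the clamp range runs from 0 to the dtype's maximum, so the upper bound is a no-op and only the relu floor is active. A minimal standalone sketch of that scalar math (illustrative names only, not part of the patch):

// add_relu_sketch.cpp -- hypothetical standalone illustration, not PR code.
#include <algorithm>
#include <cstdio>
#include <limits>

template <typename T>
T add_relu_scalar(T a, T b, T alpha) {
  // clamp(a + alpha * b, 0, numeric_limits<T>::max()) == relu(a + alpha * b),
  // since clamping at the type's maximum never changes the value.
  T sum = a + alpha * b;
  return std::min(std::numeric_limits<T>::max(), std::max(static_cast<T>(0), sum));
}

int main() {
  std::printf("%g\n", add_relu_scalar(-3.0f, 1.0f, 1.0f)); // 0: negative sum clipped
  std::printf("%g\n", add_relu_scalar(2.0f, 1.0f, 2.0f));  // 4
  return 0;
}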
aten/src/ATen/native/BinaryOps.h (+3 -0)
@@ -26,8 +26,11 @@ inline void sub_check(const Tensor& self, const Tensor& other) {

using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha);
using binary_fn = void(*)(TensorIterator&);
using binary_clamp_fn_alpha =
    void(*)(TensorIterator&, Scalar alpha, Scalar min_val, Scalar max_val);

DECLARE_DISPATCH(binary_fn_alpha, add_stub);
DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
DECLARE_DISPATCH(binary_fn_alpha, sub_stub);
DECLARE_DISPATCH(binary_fn, mul_stub);
DECLARE_DISPATCH(binary_fn, div_stub);
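The header change wires add_clamp_stub into ATen's DispatchStub machinery: DECLARE_DISPATCH creates a per-device function-pointer slot with the binary_clamp_fn_alpha signature, and REGISTER_DISPATCH (in the CPU kernel file below) fills the CPU slot. A simplified, self-contained sketch of that pattern, with hypothetical names and without the real stub's CPU-capability selection (AVX2 vs. default):

// dispatch_stub_sketch.cpp -- simplified model of DECLARE/REGISTER_DISPATCH.
#include <cassert>
#include <cstdio>

enum class DeviceType { CPU, CUDA };

using binary_clamp_fn = void (*)(float&); // stand-in for binary_clamp_fn_alpha

struct StubSketch {
  binary_clamp_fn cpu_fn = nullptr; // filled in by "registration"
  void operator()(DeviceType device, float& x) const {
    assert(device == DeviceType::CPU && cpu_fn != nullptr);
    cpu_fn(x); // route the call to the registered kernel
  }
};

StubSketch add_clamp_stub_sketch;                        // DECLARE_DISPATCH
void relu_kernel_sketch(float& x) { x = x < 0 ? 0 : x; }

int main() {
  add_clamp_stub_sketch.cpu_fn = &relu_kernel_sketch;    // REGISTER_DISPATCH
  float v = -1.5f;
  add_clamp_stub_sketch(DeviceType::CPU, v);
  std::printf("%g\n", v); // 0
  return 0;
}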
aten/src/ATen/native/cpu/BinaryOpsKernel.cpp (+22 -0)
@@ -35,6 +35,27 @@ void add_kernel(TensorIterator& iter, Scalar alpha_scalar) {
}
}

void add_clamp_kernel(TensorIterator& iter, Scalar alpha_scalar, Scalar min_val, Scalar max_val) {
  AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_clamp_cpu", [&]() {
    auto alpha = alpha_scalar.to<scalar_t>();
    auto alpha_vec = Vec256<scalar_t>(alpha);
    auto min_scalar = min_val.to<scalar_t>();
    auto min_vec = Vec256<scalar_t>(min_scalar);
    auto max_scalar = max_val.to<scalar_t>();
    auto max_vec = Vec256<scalar_t>(max_scalar);
    cpu_kernel_vec(iter,
      [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t {
        return std::min(max_scalar, std::max(min_scalar, a + alpha * b));
      },
      [=](Vec256<scalar_t> a, Vec256<scalar_t> b) __ubsan_ignore_undefined__ {
        auto add_clamp_res = vec256::fmadd(b, alpha_vec, a);
        add_clamp_res = vec256::clamp_min(add_clamp_res, min_vec);
        add_clamp_res = vec256::clamp_max(add_clamp_res, max_vec);
        return add_clamp_res;
      });
  });
}

void atan2_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "atan2_cpu", [&]() {
    cpu_kernel_vec(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
@@ -644,6 +665,7 @@ void logaddexp2_kernel(TensorIterator& iter) {


REGISTER_DISPATCH(add_stub, &add_kernel);
REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel);
REGISTER_DISPATCH(sub_stub, &sub_kernel);
REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_stub, &div_kernel);
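cpu_kernel_vec pairs a scalar lambda (used for tail elements that don't fill a vector) with a Vec256 lambda (used for full lanes); both compute a + alpha * b and clamp it to [min_val, max_val]. The per-lane arithmetic of the vector path, rewritten over a plain array for illustration (no Vec256 dependency; the lane count assumes Vec256<float>):

// lane_math_sketch.cpp -- the math of fmadd + clamp_min + clamp_max per lane.
#include <algorithm>
#include <cstdio>
#include <limits>

int main() {
  constexpr int kLanes = 8; // Vec256<float> carries 8 floats
  float a[kLanes] = {-4, -3, -2, -1, 0, 1, 2, 3};
  float b[kLanes] = {1, 1, 1, 1, 1, 1, 1, 1};
  float out[kLanes];
  const float alpha = 2.0f;
  const float min_val = 0.0f;
  const float max_val = std::numeric_limits<float>::max();
  for (int i = 0; i < kLanes; ++i) {
    float fma = a[i] + alpha * b[i];   // vec256::fmadd(b, alpha_vec, a)
    fma = std::max(fma, min_val);      // vec256::clamp_min(res, min_vec)
    out[i] = std::min(fma, max_val);   // vec256::clamp_max(res, max_vec)
  }
  for (int i = 0; i < kLanes; ++i) {
    std::printf("%g ", out[i]);        // prints: 0 0 0 1 2 3 4 5
  }
  std::printf("\n");
  return 0;
}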
aten/src/ATen/native/native_functions.yaml (+19 -0)
@@ -305,6 +305,25 @@
    SparseCUDA: add_out_sparse_cuda
    MkldnnCPU: mkldnn_add_out

- func: add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  use_c10_dispatcher: full
  variants: function
  dispatch:
    CPU: add_relu
  supports_named_tensor: True

- func: add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: add_relu_
  supports_named_tensor: True

- func: add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: add_relu_out
  supports_named_tensor: True

# For C++ only, until we have conversion from C++ numbers to Tensor
- func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
  use_c10_dispatcher: full
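Each schema above should make codegen emit a corresponding function in the at:: namespace, the usual effect of `variants: function` plus a CPU dispatch entry. Assuming that, C++ callers would use the three variants roughly as follows -- a sketch, not verified against this patch's generated headers:

#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::randn({7, 11});
  at::Tensor b = at::randn({7, 11});
  at::Tensor out = at::empty_like(a);

  at::Tensor r = at::add_relu(a, b);        // add_relu.Tensor, default alpha=1
  at::add_relu_out(out, a, b, /*alpha=*/2); // add_relu.out, writes into `out`
  at::add_relu_(a, b);                      // add_relu_.Tensor, in-place on `a`
  return 0;
}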
test/test_nn.py (+15 -0)
@@ -8716,6 +8716,21 @@ def test_fuse_module_eval_numerics(self, X, running_mean, running_var):
        self.assertEqual(Y_ref, Y_hat, msg="Conv+BN fusion results are off")


class TestAddRelu(TestCase):
    def test_add_relu(self):
        a = torch.rand((7, 11))
        b = torch.rand((7, 11))
        a = a.float()
        b = b.float()
        a = a * -10
        a = a + 5
        add_res = a + b
        relu_res = torch.relu(add_res)
        add_relu_res = torch.add_relu(a, b)

        self.assertTrue(torch.allclose(add_relu_res, relu_res))


def add_test(test, decorator=None):
    def add(test_name, fn):
        if hasattr(TestNN, test_name):
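The test scales `a` into roughly (-5, 5] so that relu genuinely clips about half of the sums; an all-positive input would pass even if the clamp were broken. A hedged C++ analogue of the same equivalence check, again assuming the generated at::add_relu binding:

#include <ATen/ATen.h>
#include <cassert>

int main() {
  at::Tensor a = at::rand({7, 11}) * -10 + 5; // mix of negative and positive
  at::Tensor b = at::rand({7, 11});
  at::Tensor reference = at::relu(a + b);     // unfused: add, then relu
  at::Tensor fused = at::add_relu(a, b);      // fused add + clamp kernel
  assert(at::allclose(fused, reference));
  return 0;
}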
torch/_overrides.py (+1 -0)
@@ -171,6 +171,7 @@ def get_testing_overrides():
        torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
        torch.acos: lambda input, out=None: -1,
        torch.acosh: lambda input, out=None: -1,
        torch.add_relu: lambda input, other, out=None: -1,
        torch.add: lambda input, other, out=None: -1,
        torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
        torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,