
Add division overload with rounding_mode selection #50280

Closed. Wants to merge 26 commits.
Changes shown are from 7 commits.
c712bbc  Add division overload with rounding_mode selection (peterbell10, Jan 8, 2021)
5e8b4a7  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
8bbfdfa  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
c582a55  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
6c6bda9  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
0687fe7  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
af02f1c  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 9, 2021)
a838dd9  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 12, 2021)
496cb93  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 12, 2021)
c28ef0f  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 13, 2021)
78f46ad  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 13, 2021)
2617933  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 13, 2021)
983d643  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 15, 2021)
703e0b7  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 15, 2021)
8784e96  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 16, 2021)
2af43d7  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 18, 2021)
69aac40  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 20, 2021)
8529096  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 28, 2021)
7e7b1d3  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 28, 2021)
84f755f  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 28, 2021)
61750ba  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 28, 2021)
bb33706  Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 29, 2021)
6bd0e9f  Update on "Add division overload with rounding_mode selection" (peterbell10, Feb 1, 2021)
0220e1c  Update on "Add division overload with rounding_mode selection" (peterbell10, Feb 2, 2021)
2fcb3a5  Update on "Add division overload with rounding_mode selection" (peterbell10, Feb 2, 2021)
71b0cfe  trial fix for mobile manifest issue (Feb 3, 2021)
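In short, the PR adds a `rounding_mode` string argument to division. A minimal usage sketch of the new ATen overloads, based on the signatures in the diff below (standalone illustration; the `at::tensor` helper calls are assumed, not part of this PR):

```cpp
#include <ATen/ATen.h>
#include <cstdio>

int main() {
  at::Tensor a = at::tensor({7.0, -7.0});
  at::Tensor b = at::tensor({2.0, 2.0});

  // "true": ordinary (true) division, promoting integer inputs to float.
  at::Tensor t = at::div(a, b, "true");   // {3.5, -3.5}
  // "trunc": rounds the quotient toward zero.
  at::Tensor r = at::div(a, b, "trunc");  // {3.0, -3.0}
  // "floor": rounds the quotient toward negative infinity.
  at::Tensor f = at::div(a, b, "floor");  // {3.0, -4.0}

  std::printf("%f %f %f\n", t[1].item<double>(), r[1].item<double>(),
              f[1].item<double>());       // -3.5 -3.0 -4.0
  return 0;
}
```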
6 changes: 6 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -1129,6 +1129,12 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
BINARY_POINTWISE_VA(rsub, Scalar);
BINARY_POINTWISE(mul);
BINARY_POINTWISE(div);
{
using Binop = Tensor (*)(const Tensor&, const Tensor&, std::string);
[Review comment, Collaborator] @rzou would you take a look here?
[Review comment, Collaborator (author)] @mruberry I think you got the wrong user. Was that meant to be @zou3519?
[Review comment, Collaborator] It was, thanks @peterbell10. Darn autocomplete! cc @zou3519
[Review comment, Contributor] this lgtm!
using Unop = Tensor (*)(const Tensor&, Scalar, std::string);
m.impl("div.Tensor_mode", binary_pointwise_batching_rule<Binop, at::div, std::string>);
m.impl("div.Scalar_mode", unwrap_and_call<Unop, at::div, Scalar, std::string>);
}

// at::pow has three out-of-place overloads
m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
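For context on the registrations above: pointwise ops need no per-sample logic under vmap, so the new div overloads can reuse the generic `binary_pointwise_batching_rule` and `unwrap_and_call` templates. A rough conceptual sketch of what such a rule reduces to (illustrative only; the real rule also unwraps `BatchedTensorImpl` and realigns batch dimensions):

```cpp
#include <ATen/ATen.h>

// Illustrative sketch: for a pointwise op, once batch dims are treated as
// ordinary leading dims, the op can be applied directly and broadcasting
// does the rest. The parameter names here are placeholders, not vmap API.
at::Tensor div_mode_batching_rule(
    const at::Tensor& self_with_batch_dims_in_front,
    const at::Tensor& other_with_batch_dims_in_front,
    std::string rounding_mode) {
  return at::div(self_with_batch_dims_in_front,
                 other_with_batch_dims_in_front,
                 std::move(rounding_mode));
}
```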
114 changes: 89 additions & 25 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -29,7 +29,9 @@ DEFINE_DISPATCH(add_stub);
DEFINE_DISPATCH(add_clamp_stub);
DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(div_stub);
DEFINE_DISPATCH(div_true_stub);
DEFINE_DISPATCH(div_floor_stub);
DEFINE_DISPATCH(div_trunc_stub);
DEFINE_DISPATCH(remainder_stub);
DEFINE_DISPATCH(atan2_stub);
DEFINE_DISPATCH(bitwise_and_stub);
@@ -148,21 +150,45 @@ Tensor& copysign_(Tensor& self, Scalar other) {
return native::copysign_(self, wrapped_scalar_tensor(other));
}

Tensor& div_out(const Tensor& self, const Tensor& other, Tensor& result) {
Tensor& div_true_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_float_op(result, self, other);
div_stub(iter.device_type(), iter);
div_true_stub(iter.device_type(), iter);
if (!result.defined()) {
result = iter.output();
}
return result;
}

Tensor& div_trunc_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_op(result, self, other);
div_trunc_stub(iter.device_type(), iter);
if (!result.defined()) {
result = iter.output();
}
return result;
}

Tensor& div_floor_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_op(result, self, other);
div_floor_stub(iter.device_type(), iter);
if (!result.defined()) {
result = iter.output();
}
return result;
}

Tensor& div_out(const Tensor& self, const Tensor& other, Tensor& result) {
return div_true_out(self, other, result);
}

Tensor div(const Tensor& self, const Tensor& other) {
Tensor result;
auto iter = TensorIterator::binary_float_op(result, self, other);
div_stub(iter.device_type(), iter);
return iter.output();
div_true_out(self, other, result);
return result;
}

Tensor& div_(Tensor& self, const Tensor& other) {
return native::div_out(self, other, self);
return div_true_out(self, other, self);
}

// WARNING: There doesn't appear to be any testing for this function
@@ -179,6 +205,38 @@ Tensor& div_(Tensor& self, Scalar other) {
return self.div_(wrapped_scalar_tensor(other)); // redispatch!
}

Tensor& div_out(const Tensor& self, const Tensor& other, std::string rounding_mode, Tensor& result) {
if (rounding_mode == "true") {
return div_true_out(self, other, result);
} else if (rounding_mode == "trunc") {
return div_trunc_out(self, other, result);
} else if (rounding_mode == "floor") {
return div_floor_out(self, other, result);
}

AT_ERROR("div expected rounding_mode to be one of 'true', 'trunc', or 'floor' "
"but found '", rounding_mode, "'");
}

Tensor div(const Tensor& self, const Tensor& other, std::string rounding_mode) {
Tensor result;
native::div_out(self, other, std::move(rounding_mode), result);
TORCH_INTERNAL_ASSERT(result.defined());
return result;
}

Tensor& div_(Tensor& self, const Tensor& other, std::string rounding_mode) {
return native::div_out(self, other, std::move(rounding_mode), self);
}

Tensor div(const Tensor& self, Scalar other, std::string rounding_mode) {
return self.div(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch!
}

Tensor& div_(Tensor& self, Scalar other, std::string rounding_mode) {
return self.div_(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch!
}

// divide, alias for div
Tensor& divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
return at::div_out(result, self, other);
@@ -200,6 +258,26 @@ Tensor& divide_(Tensor& self, Scalar other) {
return self.div_(other);
}

Tensor& divide_out(Tensor& result, const Tensor& self, const Tensor& other, std::string rounding_mode) {
return at::div_out(result, self, other, std::move(rounding_mode));
}

Tensor divide(const Tensor& self, const Tensor& other, std::string rounding_mode) {
return self.div(other, std::move(rounding_mode));
}

Tensor& divide_(Tensor& self, const Tensor& other, std::string rounding_mode) {
return self.div_(other, std::move(rounding_mode));
}

Tensor divide(const Tensor& self, Scalar other, std::string rounding_mode) {
return self.div(other, std::move(rounding_mode));
}

Tensor& divide_(Tensor& self, Scalar other, std::string rounding_mode) {
return self.div_(other, std::move(rounding_mode));
}

// true_divide, an alias for div
Tensor& true_divide_out(Tensor& result, const Tensor& self, const Tensor& divisor) {
return at::div_out(result, self, divisor);
@@ -239,28 +317,14 @@ Tensor& remainder_(Tensor& self, const Tensor& other) {
}

Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_op(result, self, other);
div_stub(iter.device_type(), iter);

if (result.is_floating_point()) {
result.trunc_();
}

return result;
// FIXME: Not actually doing floor division
return div_trunc_out(self, other, result);
}

Tensor floor_divide(const Tensor& self, const Tensor& other) {
Tensor result;
auto iter = TensorIterator::binary_op(result, self, other);

div_stub(iter.device_type(), iter);

auto out = iter.output();
if (out.is_floating_point()) {
out.trunc_();
}

return out;
native::floor_divide_out(result, self, other);
return result;
}

Tensor& floor_divide_(Tensor& self, const Tensor& other) {
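A small usage sketch of the out-variants and aliases added above (assumed call patterns, standalone; not part of the diff):

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::tensor({5, -5});
  at::Tensor b = at::tensor({3, 3});
  at::Tensor out = at::empty_like(a);

  // divide_out forwards to at::div_out with the same rounding_mode.
  at::divide_out(out, a, b, "floor");  // {1, -2}
  at::divide_out(out, a, b, "trunc");  // {1, -1}

  // Anything other than "true", "trunc", or "floor" hits the AT_ERROR
  // path above; e.g. at::div(a, b, "round") would throw.
  return 0;
}
```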
4 changes: 3 additions & 1 deletion aten/src/ATen/native/BinaryOps.h
@@ -37,7 +37,9 @@ DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
DECLARE_DISPATCH(binary_fn_alpha, sub_stub);
DECLARE_DISPATCH(binary_fn, mul_stub);
DECLARE_DISPATCH(binary_fn, div_stub);
DECLARE_DISPATCH(binary_fn, div_true_stub);
DECLARE_DISPATCH(binary_fn, div_floor_stub);
DECLARE_DISPATCH(binary_fn, div_trunc_stub);
DECLARE_DISPATCH(binary_fn, remainder_stub);
DECLARE_DISPATCH(binary_fn, atan2_stub);
DECLARE_DISPATCH(binary_fn, bitwise_and_stub);
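For readers unfamiliar with the stub machinery in this header, a minimal analog of the DECLARE_DISPATCH / REGISTER_DISPATCH pattern (simplified sketch; ATen's real stubs also key on device type and CPU capability):

```cpp
#include <cstdio>

// Minimal analog: a per-op function-pointer slot, declared centrally and
// filled in by the backend kernel file at static-initialization time.
struct TensorIterator { /* opaque in this sketch */ };
using binary_fn = void (*)(TensorIterator&);

struct DispatchStub {
  binary_fn cpu_impl = nullptr;
  void operator()(TensorIterator& iter) const { cpu_impl(iter); }
};

DispatchStub div_floor_stub;  // DECLARE_DISPATCH + DEFINE_DISPATCH analog

// This would live in the CPU kernel translation unit:
void div_floor_kernel(TensorIterator&) { std::printf("floor kernel\n"); }

struct RegisterDivFloor {  // REGISTER_DISPATCH analog
  RegisterDivFloor() { div_floor_stub.cpu_impl = &div_floor_kernel; }
} register_div_floor;

int main() {
  TensorIterator iter;
  div_floor_stub(iter);  // dispatches to the registered kernel
  return 0;
}
```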
68 changes: 61 additions & 7 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -92,24 +92,76 @@ void mul_kernel(TensorIterator& iter) {
}
}

void div_kernel(TensorIterator& iter) {
if (isIntegralType(iter.dtype(), /*includeBool*/ false)) {
void div_true_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "div_true_cpu", [&]() {
cpu_kernel_vec(iter,
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
return a / b;
},
[](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return a / b;
});
});
}

void div_trunc_kernel(TensorIterator& iter) {
auto dtype = iter.common_dtype();
if (isIntegralType(dtype, /*includeBool*/ false)) {
// There's no SIMD integer division, so don't try to vectorize it.
// TODO: if the divisor is a scalar, rewrite as multiplication by a constant.
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_cpu", [&]() {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_trunc_cpu", [&]() {
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
TORCH_CHECK(b != 0, "ZeroDivisionError");
return a / b;
});
});
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "div_cpu", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "div_trunc_cpu", [&]() {
cpu_kernel_vec(iter,
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
return std::trunc(a / b);
},
[](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return (a / b).trunc();
});
});
}
}

void div_floor_kernel(TensorIterator& iter) {
const auto dtype = iter.common_dtype();
if (dtype == kByte) {
// In the special case of unsigned integer division, floor division is
// equivalent to truncation division (since the signs of the divisor and
// dividend are always the same)
return div_trunc_kernel(iter);
} else if (isIntegralType(dtype, /*includeBool*/ false)) {
// There's no SIMD integer division, so don't try to vectorize it.
// TODO: if the divisor is a scalar, rewrite as multiplication by a constant.
AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "div_floor_cpu", [&]() {
[Review comment, Collaborator] This is inconsistent between using dtype and iter.common_dtype().
[Review comment, Collaborator (author)] I have removed all uses of iter.dtype(). If instead you meant the variable dtype, then I would note that it's assigned from iter.common_dtype() above. Just a bit less to type.
[Review comment, Collaborator] I realize the value is the same, just for readability the code might want to stick to either dtype or iter.common_dtype(). No big deal either way.

cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {

TORCH_CHECK(b != 0, "ZeroDivisionError");
if ((a < 0) != (b < 0)) {
// Subtracts one from the results of truncation division if the
// divisor and dividend have different sign(bit)s and the remainder of
// the division is nonzero
const auto quot = a / b;
const auto rem = a % b;
return rem ? quot - 1 : quot;
}

return a / b;
});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "div_floor_cpu", [&]() {
[Review comment, Collaborator] Same dtype vs iter.common_dtype here, too.

cpu_kernel_vec(iter,
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
return a / b;
return std::floor(a / b);
},
[](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return a / b;
return (a / b).floor();
});
});
}
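A standalone check of the floored-division correction implemented above (plain C++ sketch, not part of the diff):

```cpp
#include <cassert>
#include <cstdio>

// C++ integer division truncates toward zero, so when the operands have
// different signs and the remainder is nonzero, subtract one, exactly as
// the integer branch of div_floor_kernel does.
long long floor_div(long long a, long long b) {
  const long long quot = a / b;
  const long long rem = a % b;
  return ((a < 0) != (b < 0)) && rem != 0 ? quot - 1 : quot;
}

int main() {
  assert(floor_div(7, 2) == 3);    // signs match: same as truncation
  assert(floor_div(-7, 2) == -4);  // truncation gives -3; corrected to -4
  assert(floor_div(7, -2) == -4);  // floor(-3.5) = -4
  assert(floor_div(-6, 2) == -3);  // exact division: no correction
  // For unsigned types (the kByte fast path above), signs always match,
  // so floor division coincides with truncation.
  std::printf("all checks passed\n");
  return 0;
}
```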
@@ -838,7 +890,9 @@ REGISTER_DISPATCH(add_stub, &add_kernel);
REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel);
REGISTER_DISPATCH(sub_stub, &sub_kernel);
REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_stub, &div_kernel);
REGISTER_DISPATCH(div_true_stub, &div_true_kernel);
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel);
REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel);
REGISTER_DISPATCH(remainder_stub, &remainder_kernel);
REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
86 changes: 81 additions & 5 deletions aten/src/ATen/native/cuda/BinaryMulDivKernel.cu
@@ -44,26 +44,100 @@ struct MulFunctor<bool> {
};


void div_kernel_cuda(TensorIterator& iter) {
if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2)) {
void div_true_kernel_cuda(TensorIterator& iter) {
if (iter.is_cpu_scalar(2)) {
// optimization for floating-point types: if the second operand is a CPU
// scalar, compute a * reciprocal(b). Note that this may lose one bit of
// precision compared to computing the division.
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_true_cuda", [&]() {
using accscalar_t = at::acc_type<scalar_t, true>;
auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
iter.remove_operand(2);
MulScalarFunctor<scalar_t, decltype(inv_b)> f(inv_b);
gpu_kernel(iter, f);
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_true_cuda", [&]() {
DivFunctor<scalar_t> f;
gpu_kernel_with_scalars(iter, f);
});
}
}

void div_trunc_kernel_cuda(TensorIterator& iter) {
auto dtype = iter.common_dtype();
if (isIntegralType(dtype, /*includeBool*/ false)) {
AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cuda", [&]() {
gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
return a / b;
});
});
} else if (iter.is_cpu_scalar(2)) {
// optimization for floating-point types: if the second operand is a CPU
// scalar, compute a * reciprocal(b). Note that this may lose one bit of
// precision compared to computing the division.
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() {
using accscalar_t = at::acc_type<scalar_t, true>;
auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
iter.remove_operand(2);
gpu_kernel(iter, [inv_b] GPU_LAMBDA (scalar_t a) -> scalar_t {
return std::trunc(a * inv_b);
});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() {
gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
return std::trunc(a / b);
});
});
}
}

void div_floor_kernel_cuda(TensorIterator& iter) {
const auto dtype = iter.common_dtype();
if (dtype == kByte) {
// In the special case of unsigned integer division, floor division is
// equivalent to truncation division (since the signs of the divisor and
// dividend are always the same)
return div_trunc_kernel_cuda(iter);
} else if (isIntegralType(dtype, /*includeBool*/ false)) {
// There's no SIMD integer division, so don't try to vectorize it.
// TODO: if the divisor is a scalar, rewrite as multiplication by a constant.
AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "div_floor_cuda", [&]() {
gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
if ((a < 0) != (b < 0)) {
// Subtracts one from the results of truncation division if the
// divisor and dividend have different sign(bit)s and the remainder of
// the division is nonzero
const auto quot = a / b;
const auto rem = a % b;
return rem ? quot - 1 : quot;
}

return a / b;
});
});
} else if (iter.is_cpu_scalar(2)) {
// optimization for floating-point types: if the second operand is a CPU
// scalar, compute a * reciprocal(b). Note that this may lose one bit of
// precision compared to computing the division.
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() {
using accscalar_t = at::acc_type<scalar_t, true>;
auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
iter.remove_operand(2);
gpu_kernel(iter, [inv_b] GPU_LAMBDA (scalar_t a) -> scalar_t {
return std::floor(a * inv_b);
});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() {
gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
return std::floor(a / b);
});
});
}
}

void mul_kernel_cuda(TensorIterator& iter) {
if (!isIntegralType(iter.common_dtype(), /*includeBool*/ true) &&
(iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) {
@@ -86,7 +160,9 @@ void mul_kernel_cuda(TensorIterator& iter) {
}
}

REGISTER_DISPATCH(div_stub, &div_kernel_cuda);
REGISTER_DISPATCH(div_true_stub, &div_true_kernel_cuda);
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel_cuda);
REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel_cuda);
REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda);

}} // namespace at::native
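As a footnote on the a * (1/b) optimization used in the CUDA kernels above: because the reciprocal is itself rounded before the multiply, the product can differ from the directly rounded quotient by one ulp. A small host-side sketch of the effect (illustrative only; not part of the diff):

```cpp
#include <cstdio>

int main() {
  // Scan a small range for a pair where a / b and a * (1.0f / b) round
  // differently. Whether such a pair exists in this range is not
  // guaranteed; the scan is purely illustrative.
  for (int ai = 1; ai <= 1000; ++ai) {
    for (int bi = 1; bi <= 1000; ++bi) {
      const float a = static_cast<float>(ai);
      const float b = static_cast<float>(bi);
      if (a / b != a * (1.0f / b)) {
        std::printf("%d / %d: direct=%.9g, via reciprocal=%.9g\n",
                    ai, bi, a / b, a * (1.0f / b));
        return 0;
      }
    }
  }
  std::printf("no one-ulp discrepancy found in scanned range\n");
  return 0;
}
```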