Add division overload with rounding_mode selection #50280

Status: Closed

peterbell10 wants to merge 26 commits (changes shown from 24 of the 26 commits).

Commits:
c712bbc - Add division overload with rounding_mode selection (peterbell10, Jan 8, 2021)
5e8b4a7 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
8bbfdfa - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
c582a55 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
6c6bda9 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
0687fe7 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 8, 2021)
af02f1c - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 9, 2021)
a838dd9 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 12, 2021)
496cb93 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 12, 2021)
c28ef0f - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 13, 2021)
78f46ad - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 13, 2021)
2617933 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 13, 2021)
983d643 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 15, 2021)
703e0b7 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 15, 2021)
8784e96 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 16, 2021)
2af43d7 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 18, 2021)
69aac40 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 20, 2021)
8529096 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 28, 2021)
7e7b1d3 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 28, 2021)
84f755f - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 28, 2021)
61750ba - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 28, 2021)
bb33706 - Update on "Add division overload with rounding_mode selection" (peterbell10, Jan 29, 2021)
6bd0e9f - Update on "Add division overload with rounding_mode selection" (peterbell10, Feb 1, 2021)
0220e1c - Update on "Add division overload with rounding_mode selection" (peterbell10, Feb 2, 2021)
2fcb3a5 - Update on "Add division overload with rounding_mode selection" (peterbell10, Feb 2, 2021)
71b0cfe - trial fix for mobile manifest issue (Feb 3, 2021)
6 changes: 6 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -1140,6 +1140,12 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
BINARY_POINTWISE_VA(rsub, Scalar);
BINARY_POINTWISE(mul);
BINARY_POINTWISE(div);
{
using Binop = Tensor (*)(const Tensor&, const Tensor&, std::string);
[Inline review thread on this block]

Collaborator (mruberry): @rzou would you take a look here?

Collaborator Author (peterbell10): @mruberry I think you got the wrong user. Was that meant to be @zou3519?

Collaborator (mruberry): It was, thanks @peterbell10. Darn autocomplete! cc @zou3519

Contributor: this lgtm!

using Unop = Tensor (*)(const Tensor&, Scalar, std::string);
m.impl("div.Tensor_mode", binary_pointwise_batching_rule<Binop, at::div, std::string>);
m.impl("div.Scalar_mode", unwrap_and_call<Unop, at::div, Scalar, std::string>);
}

// at::pow has three out-of-place overloads
m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
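For orientation, here is a minimal usage sketch of the overloads this PR adds, written against the C++ frontend. It assumes a built libtorch; the tensors and the commented values are illustrative, and the mode strings "true"/"trunc"/"floor" are the ones accepted by this version of the PR.

#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor a = torch::tensor({ 7.0, -7.0});
  torch::Tensor b = torch::tensor({ 2.0,  2.0});

  // div.Tensor_mode: the same overload the Batched registration above forwards to.
  auto t = a.div(b, "true");   // { 3.5, -3.5}  true division (default behaviour)
  auto r = a.div(b, "trunc");  // { 3., -3.}    rounds the quotient toward zero
  auto f = a.div(b, "floor");  // { 3., -4.}    rounds the quotient toward -infinity

  std::cout << t << "\n" << r << "\n" << f << std::endl;
  return 0;
}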
12 changes: 10 additions & 2 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -43,12 +43,13 @@ namespace at {
namespace vec256 {
// See Note [Acceptable use of anonymous namespace in header]
namespace {
// at::Half should be treated as floating point
// at::Half and at::BFloat16 should be treated as floating point
template <typename T>
struct is_floating_point:
std::integral_constant<bool,
std::is_floating_point<T>::value ||
std::is_same<T, at::Half>::value> {
std::is_same<T, at::Half>::value ||
std::is_same<T, at::BFloat16>::value> {
};

template<size_t n> struct int_of_size;
@@ -316,6 +317,13 @@ struct Vec256 {
}
return ret;
}
Vec256<T> copysign(const Vec256<T> &sign) const {
Vec256<T> ret;
for (int64_t i = 0; i < size(); i++) {
ret[i] = std::copysign(values[i], sign[i]);
}
return ret;
}
Vec256<T> erf() const {
return map(std::erf);
}
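The generic Vec256 fallback above simply applies std::copysign lane by lane. As a reminder of the scalar semantics it relies on (including signed zero, which matters for the floor-division kernel later in this PR), here is a small standalone check, not part of the patch:

#include <cassert>
#include <cmath>

int main() {
  assert(std::copysign(3.0, -0.5) == -3.0);           // magnitude of first, sign of second
  assert(std::copysign(-2.0, 4.0) ==  2.0);
  assert(std::signbit(std::copysign(0.0, -1.0)));     // zero keeps the sign too: -0.0
  return 0;
}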
9 changes: 9 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_bfloat16.h
@@ -248,6 +248,15 @@ template <> class Vec256<BFloat16> {
auto o2 = Sleef_atan2f8_u10(hi, b2);
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> copysign(const Vec256<BFloat16> &sign) const {
// copy sign bit (0x8000) from sign and remaining bits from values
__m256i mask_value = _mm256_set1_epi32(~0x80008000);
__m256i mask_signbit = _mm256_set1_epi32(0x80008000);
return Vec256<BFloat16>(
_mm256_or_si256(
_mm256_and_si256(values, mask_value),
_mm256_and_si256(sign, mask_signbit)));
}
Vec256<BFloat16> erf() const {
return map(Sleef_erff8_u10);
}
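The AVX2 version above packs two 16-bit bfloat16 lanes per 32-bit mask element, which is why the constants are 0x80008000 and its complement. The same bit trick on a single bfloat16 bit pattern, written as plain scalar code for illustration (not part of the patch):

#include <cassert>
#include <cstdint>

// Keep the magnitude bits of `value` and take only the sign bit from `sign`;
// this is the per-element scalar version of the 0x80008000 mask trick above.
uint16_t copysign_bits16(uint16_t value, uint16_t sign) {
  return static_cast<uint16_t>((value & 0x7FFF) | (sign & 0x8000));
}

int main() {
  // 0x3F80 is bfloat16 1.0; taking the sign from 0xBF80 (-1.0) yields 0xBF80 (-1.0).
  assert(copysign_bits16(0x3F80, 0xBF80) == 0xBF80);
  // Taking a positive sign strips the sign bit from a negative value.
  assert(copysign_bits16(0xBF80, 0x0000) == 0x3F80);
  return 0;
}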
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_double.h
@@ -140,6 +140,9 @@ template <> class Vec256<double> {
Vec256<double> atan2(const Vec256<double> &b) const {
return Vec256<double>(Sleef_atan2d4_u10(values, b));
}
Vec256<double> copysign(const Vec256<double> &sign) const {
return Vec256<double>(Sleef_copysignd4(values, sign));
}
Vec256<double> erf() const {
return Vec256<double>(Sleef_erfd4_u10(values));
}
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_float.h
@@ -147,6 +147,9 @@ template <> class Vec256<float> {
Vec256<float> atan2(const Vec256<float> &b) const {
return Vec256<float>(Sleef_atan2f8_u10(values, b));
}
Vec256<float> copysign(const Vec256<float> &sign) const {
return Vec256<float>(Sleef_copysignf8(values, sign));
}
Vec256<float> erf() const {
return Vec256<float>(Sleef_erff8_u10(values));
}
114 changes: 89 additions & 25 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -29,7 +29,9 @@ DEFINE_DISPATCH(add_stub);
DEFINE_DISPATCH(add_clamp_stub);
DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(div_stub);
DEFINE_DISPATCH(div_true_stub);
DEFINE_DISPATCH(div_floor_stub);
DEFINE_DISPATCH(div_trunc_stub);
DEFINE_DISPATCH(remainder_stub);
DEFINE_DISPATCH(atan2_stub);
DEFINE_DISPATCH(bitwise_and_stub);
@@ -150,21 +152,45 @@ Tensor& copysign_(Tensor& self, Scalar other) {
return native::copysign_(self, wrapped_scalar_tensor(other));
}

Tensor& div_out(const Tensor& self, const Tensor& other, Tensor& result) {
Tensor& div_true_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_float_op(result, self, other);
div_stub(iter.device_type(), iter);
div_true_stub(iter.device_type(), iter);
if (!result.defined()) {
result = iter.output();
}
return result;
}

Tensor& div_trunc_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_op(result, self, other);
div_trunc_stub(iter.device_type(), iter);
if (!result.defined()) {
result = iter.output();
}
return result;
}

Tensor& div_floor_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_op(result, self, other);
div_floor_stub(iter.device_type(), iter);
if (!result.defined()) {
result = iter.output();
}
return result;
}

Tensor& div_out(const Tensor& self, const Tensor& other, Tensor& result) {
return div_true_out(self, other, result);
}

Tensor div(const Tensor& self, const Tensor& other) {
Tensor result;
auto iter = TensorIterator::binary_float_op(result, self, other);
div_stub(iter.device_type(), iter);
return iter.output();
div_true_out(self, other, result);
return result;
}

Tensor& div_(Tensor& self, const Tensor& other) {
return native::div_out(self, other, self);
return div_true_out(self, other, self);
}

// WARNING: There doesn't appear to be any testing for this function
@@ -181,6 +207,38 @@ Tensor& div_(Tensor& self, Scalar other) {
return self.div_(wrapped_scalar_tensor(other)); // redispatch!
}

Tensor& div_out(const Tensor& self, const Tensor& other, std::string rounding_mode, Tensor& result) {
if (rounding_mode == "true") {
return div_true_out(self, other, result);
} else if (rounding_mode == "trunc") {
return div_trunc_out(self, other, result);
} else if (rounding_mode == "floor") {
return div_floor_out(self, other, result);
}

TORCH_CHECK(false,
"div expected rounding_mode to be one of 'true', 'trunc', or 'floor' "
"but found '", rounding_mode, "'");
}

Tensor div(const Tensor& self, const Tensor& other, std::string rounding_mode) {
Tensor result;
native::div_out(self, other, std::move(rounding_mode), result);
return result;
}

Tensor& div_(Tensor& self, const Tensor& other, std::string rounding_mode) {
return native::div_out(self, other, std::move(rounding_mode), self);
}

Tensor div(const Tensor& self, Scalar other, std::string rounding_mode) {
return self.div(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch!
}

Tensor& div_(Tensor& self, Scalar other, std::string rounding_mode) {
return self.div_(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch!
}

// divide, alias for div
Tensor& divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
return at::div_out(result, self, other);
@@ -202,6 +260,26 @@ Tensor& divide_(Tensor& self, Scalar other) {
return self.div_(other);
}

Tensor& divide_out(const Tensor& self, const Tensor& other, std::string rounding_mode, Tensor& result) {
return at::div_out(result, self, other, std::move(rounding_mode));
}

Tensor divide(const Tensor& self, const Tensor& other, std::string rounding_mode) {
return self.div(other, std::move(rounding_mode));
}

Tensor& divide_(Tensor& self, const Tensor& other, std::string rounding_mode) {
return self.div_(other, std::move(rounding_mode));
}

Tensor divide(const Tensor& self, Scalar other, std::string rounding_mode) {
return self.div(other, std::move(rounding_mode));
}

Tensor& divide_(Tensor& self, Scalar other, std::string rounding_mode) {
return self.div_(other, std::move(rounding_mode));
}

// true_divide, an alias for div
Tensor& true_divide_out(Tensor& result, const Tensor& self, const Tensor& divisor) {
return at::div_out(result, self, divisor);
@@ -241,28 +319,14 @@ Tensor& remainder_(Tensor& self, const Tensor& other) {
}

Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_op(result, self, other);
div_stub(iter.device_type(), iter);

if (result.is_floating_point()) {
result.trunc_();
}

return result;
// FIXME: Not actually doing floor division (#43874)
return div_trunc_out(self, other, result);
}

Tensor floor_divide(const Tensor& self, const Tensor& other) {
Tensor result;
auto iter = TensorIterator::binary_op(result, self, other);

div_stub(iter.device_type(), iter);

auto out = iter.output();
if (out.is_floating_point()) {
out.trunc_();
}

return out;
native::floor_divide_out(result, self, other);
return result;
}

Tensor& floor_divide_(Tensor& self, const Tensor& other) {
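A companion sketch for the out-variant and the rounding_mode validation added above (again assuming a built libtorch; the tensors are illustrative):

#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor a = torch::tensor({5.0, -5.0});
  torch::Tensor b = torch::tensor({3.0,  3.0});
  torch::Tensor out = torch::empty_like(a);

  // Same argument order as the divide_out alias above: result first, then inputs.
  at::div_out(out, a, b, "floor");          // out is now {1., -2.}

  try {
    a.div(b, "ceil");                       // not one of 'true', 'trunc', 'floor'
  } catch (const c10::Error& e) {
    std::cout << e.what() << std::endl;     // TORCH_CHECK message from div_out
  }
  return 0;
}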
4 changes: 3 additions & 1 deletion aten/src/ATen/native/BinaryOps.h
@@ -37,7 +37,9 @@ DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
DECLARE_DISPATCH(binary_fn_alpha, sub_stub);
DECLARE_DISPATCH(binary_fn, mul_stub);
DECLARE_DISPATCH(binary_fn, div_stub);
DECLARE_DISPATCH(binary_fn, div_true_stub);
DECLARE_DISPATCH(binary_fn, div_floor_stub);
DECLARE_DISPATCH(binary_fn, div_trunc_stub);
DECLARE_DISPATCH(binary_fn, remainder_stub);
DECLARE_DISPATCH(binary_fn, atan2_stub);
DECLARE_DISPATCH(binary_fn, bitwise_and_stub);
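For context, the three stubs declared here follow ATen's usual declare/define/register dispatch pattern; the outline below restates how the pieces touched in this PR fit together (the macro and stub names come from the diffs, the file annotations are comments):

// aten/src/ATen/native/BinaryOps.h         -- declare the stub with its kernel signature
DECLARE_DISPATCH(binary_fn, div_floor_stub);

// aten/src/ATen/native/BinaryOps.cpp       -- define the stub; the op builds a
// TensorIterator and calls: div_floor_stub(iter.device_type(), iter);
DEFINE_DISPATCH(div_floor_stub);

// aten/src/ATen/native/cpu/BinaryOpsKernel.cpp -- register the CPU kernel for it
REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel);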
103 changes: 96 additions & 7 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -92,29 +92,116 @@ void mul_kernel(TensorIterator& iter) {
}
}

void div_kernel(TensorIterator& iter) {
if (isIntegralType(iter.dtype(), /*includeBool*/ false)) {
void div_true_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "div_true_cpu", [&]() {
cpu_kernel_vec(iter,
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
return a / b;
},
[](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return a / b;
});
});
}

void div_trunc_kernel(TensorIterator& iter) {
const auto dtype = iter.common_dtype();
if (isIntegralType(dtype, /*includeBool*/ false)) {
// There's no SIMD integer division, so don't try to vectorize it.
// TODO: if the divisor is a scalar, rewrite as multiplication by a constant.
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_cpu", [&]() {
AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cpu", [&]() {
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
TORCH_CHECK(b != 0, "ZeroDivisionError");
return a / b;
});
});
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "div_cpu", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "div_trunc_cpu", [&]() {
cpu_kernel_vec(iter,
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
return a / b;
return std::trunc(a / b);
},
[](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return a / b;
return (a / b).trunc();
});
});
}
}

// NOTE: [Floor Division in Python]
// Python's __floordiv__ operator is more complicated than just floor(a / b).
// It aims to maintain the property: a == (a // b) * b + remainder(a, b)
// which can otherwise fail due to rounding errors in the remainder.
// So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
// With some additional fix-ups added to the result.
//
// For reference, see CPython's implementation:
// https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
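//
// Worked example (illustration only, not part of the patch), with a = -7.0, b = 2.0:
//   mod = std::fmod(-7.0, 2.0) = -1.0      (fmod keeps the sign of the dividend)
//   div = (a - mod) / b = (-7.0 + 1.0) / 2.0 = -3.0
//   mod != 0 and (b < 0) != (mod < 0), so subtract one:  div = -4.0
//   floor(-4.0) = -4.0, which matches Python: -7.0 // 2.0 == -4.0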

void div_floor_kernel(TensorIterator& iter) {
const auto dtype = iter.common_dtype();
if (dtype == kByte) {
// In the special case of unsigned integer division, floor division is
// equivalent to truncation division (since the signs of the divisor and
// dividend are always the same)
return div_trunc_kernel(iter);
} else if (isIntegralType(dtype, /*includeBool*/ false)) {
// There's no SIMD integer division, so don't try to vectorize it.
AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_cpu", [&]() {
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
TORCH_CHECK(b != 0, "ZeroDivisionError");
if ((a < 0) != (b < 0)) {
// Subtracts one from the results of truncation division if the
// divisor and dividend have different sign(bit)s and the remainder of
// the division is nonzero
const auto quot = a / b;
const auto rem = a % b;
return rem ? quot - 1 : quot;
}

return a / b;
});
});
} else {
// See NOTE: [Floor Division in Python]
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "div_floor_cpu", [&]() {
using vec_t = Vec256<scalar_t>;
cpu_kernel_vec(iter,
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
auto mod = std::fmod(a, b);
auto div = (a - mod) / b;
if ((mod != 0) && (b < 0) != (mod < 0)) {
div -= scalar_t(1);
}

scalar_t floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > scalar_t(0.5)) {
floordiv += scalar_t(1.0);
}
} else {
floordiv = std::copysign(scalar_t(0), a / b);
}
return floordiv;
},
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t>{
[Inline review thread on this block]

Collaborator (mruberry): This is triggering some internal build issues. Adding a vectorized function can be a little tricky because we often have to stub them out on some platforms, like Android. Since we're so close to the branch cut, I propose removing the copysign implementation and this vectorized implementation. We can file an issue and add them back in a later PR where we can take our time and focus on that issue.

Collaborator Author (peterbell10): @mruberry this should be good now. Removed all Vec256 changes and unvectorized floor_divide.

using vec_t = Vec256<scalar_t>;
auto mod = a.fmod(b);
auto div = (a - mod) / b;
const auto zero = vec_t(0);
auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
const auto one = vec_t(1);
div = vec_t::blendv(div, div - one, mask);
auto floordiv = div.floor();
mask = (div - floordiv) > vec_t(0.5);
floordiv = vec_t::blendv(floordiv, floordiv + one, mask);
return vec_t::blendv(floordiv, zero.copysign(a / b), div == zero);
});
});
}
}

void remainder_kernel(TensorIterator& iter) {
if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) {
AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cpu", [&]() {
@@ -861,7 +948,9 @@ REGISTER_DISPATCH(add_stub, &add_kernel);
REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel);
REGISTER_DISPATCH(sub_stub, &sub_kernel);
REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_stub, &div_kernel);
REGISTER_DISPATCH(div_true_stub, &div_true_kernel);
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel);
REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel);
REGISTER_DISPATCH(remainder_stub, &remainder_kernel);
REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
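Finally, a self-contained scalar restatement of the floor-division fix-ups in div_floor_kernel, handy for checking the behaviour outside of TensorIterator. This is an illustration written for this summary, not code from the PR:

#include <cassert>
#include <cmath>

// Scalar restatement of the floating-point path of div_floor_kernel above.
double div_floor(double a, double b) {
  double mod = std::fmod(a, b);
  double div = (a - mod) / b;
  if ((mod != 0) && ((b < 0) != (mod < 0))) {
    div -= 1;                                // fix up mixed-sign quotients
  }
  double floordiv;
  if (div != 0) {
    floordiv = std::floor(div);
    if (div - floordiv > 0.5) {
      floordiv += 1.0;                       // guard against rounding error in div
    }
  } else {
    floordiv = std::copysign(0.0, a / b);    // zero keeps the sign of the true quotient
  }
  return floordiv;
}

int main() {
  assert(div_floor( 7.0,  2.0) ==  3.0);
  assert(div_floor(-7.0,  2.0) == -4.0);     // Python: -7.0 // 2.0 == -4.0
  assert(div_floor( 7.0, -2.0) == -4.0);
  assert(std::signbit(div_floor(-0.5, 8.0))); // signed zero result: -0.0
  return 0;
}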