-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add division overload with rounding_mode selection #50280
Changes from 7 commits
c712bbc
5e8b4a7
8bbfdfa
c582a55
6c6bda9
0687fe7
af02f1c
a838dd9
496cb93
c28ef0f
78f46ad
2617933
983d643
703e0b7
8784e96
2af43d7
69aac40
8529096
7e7b1d3
84f755f
61750ba
bb33706
6bd0e9f
0220e1c
2fcb3a5
71b0cfe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,24 +92,76 @@ void mul_kernel(TensorIterator& iter) { | |
} | ||
} | ||
|
||
void div_kernel(TensorIterator& iter) { | ||
if (isIntegralType(iter.dtype(), /*includeBool*/ false)) { | ||
// CPU kernel for "true" division: both lambdas below return the plain quotient a / b | ||
// with no rounding step applied afterwards. | ||
// Dispatch is on iter.common_dtype() over the floating-point and complex types, | ||
// plus BFloat16 and Half; integral dtypes are not part of this dispatch. | ||
void div_true_kernel(TensorIterator& iter) { | ||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "div_true_cpu", [&]() { | ||
// cpu_kernel_vec receives a scalar lambda and a Vec256 lambda that compute the same quotient. | ||
cpu_kernel_vec(iter, | ||
// Scalar form. NOTE(review): the __ubsan_ignore_float_divide_by_zero__ annotation suggests | ||
// a zero divisor is intentionally left unchecked here - confirm against the macro definition. | ||
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { | ||
return a / b; | ||
}, | ||
// Vectorized form operating on Vec256<scalar_t> operands. | ||
[](Vec256<scalar_t> a, Vec256<scalar_t> b) { | ||
return a / b; | ||
}); | ||
}); | ||
} | ||
|
||
void div_trunc_kernel(TensorIterator& iter) { | ||
auto dtype = iter.common_dtype(); | ||
if (isIntegralType(dtype, /*includeBool*/ false)) { | ||
// There's no SIMD integer division, so don't try to vectorize it. | ||
// TODO: if the divisor is a scalar, rewrite as multiplication by a constant. | ||
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_cpu", [&]() { | ||
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_trunc_cpu", [&]() { | ||
peterbell10 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { | ||
TORCH_CHECK(b != 0, "ZeroDivisionError"); | ||
return a / b; | ||
}); | ||
}); | ||
} else { | ||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "div_cpu", [&]() { | ||
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "div_trunc_cpu", [&]() { | ||
cpu_kernel_vec(iter, | ||
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { | ||
return std::trunc(a / b); | ||
}, | ||
[](Vec256<scalar_t> a, Vec256<scalar_t> b) { | ||
return (a / b).trunc(); | ||
}); | ||
}); | ||
} | ||
} | ||
|
||
void div_floor_kernel(TensorIterator& iter) { | ||
const auto dtype = iter.common_dtype(); | ||
if (dtype == kByte) { | ||
// In the special case of unsigned integer division, floor division is | ||
// equivalent to truncation division (since the signs of the divisor and | ||
// dividend are always the same) | ||
return div_trunc_kernel(iter); | ||
} else if (isIntegralType(dtype, /*includeBool*/ false)) { | ||
// There's no SIMD integer division, so don't try to vectorize it. | ||
// TODO: if the divisor is a scalar, rewrite as multiplication by a constant. | ||
AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "div_floor_cpu", [&]() { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is inconsistent between using `dtype` and `iter.common_dtype()`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have removed all uses of [identifier lost in extraction; presumably `iter.dtype()`]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I realize the value is the same, just for readability the code might want to stick to either `dtype` or `iter.common_dtype()`. |
||
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { | ||
|
||
TORCH_CHECK(b != 0, "ZeroDivisionError"); | ||
if ((a < 0) != (b < 0)) { | ||
// Subtracts one from the results of truncation division if the | ||
// divisor and dividend have different sign(bit)s and the remainder of | ||
// the division is nonzero | ||
const auto quot = a / b; | ||
peterbell10 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const auto rem = a % b; | ||
return rem ? quot - 1 : quot; | ||
} | ||
|
||
return a / b; | ||
}); | ||
}); | ||
} else { | ||
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "div_floor_cpu", [&]() { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same. |
||
cpu_kernel_vec(iter, | ||
[](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { | ||
return a / b; | ||
return std::floor(a / b); | ||
}, | ||
[](Vec256<scalar_t> a, Vec256<scalar_t> b) { | ||
return a / b; | ||
return (a / b).floor(); | ||
}); | ||
}); | ||
} | ||
|
@@ -838,7 +890,9 @@ REGISTER_DISPATCH(add_stub, &add_kernel); | |
REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel); | ||
REGISTER_DISPATCH(sub_stub, &sub_kernel); | ||
REGISTER_DISPATCH(mul_stub, &mul_kernel); | ||
REGISTER_DISPATCH(div_stub, &div_kernel); | ||
REGISTER_DISPATCH(div_true_stub, &div_true_kernel); | ||
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel); | ||
REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel); | ||
REGISTER_DISPATCH(remainder_stub, &remainder_kernel); | ||
REGISTER_DISPATCH(atan2_stub, &atan2_kernel); | ||
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rzou would you take a look here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mruberry I think you got the wrong user. Was that meant to be @zou3519?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was, thanks @peterbell10. Darn autocomplete!
cc @zou3519
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this lgtm!