Add 0dim Tensor overload for _foreach_div #113688

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
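For context before the diffs below, a minimal usage sketch of what this PR enables, assuming a build that includes these changes and an available CUDA device; the divisor is a single 0-dim Tensor applied to every list element:

import torch

# Existing overloads take a Scalar, ScalarList, or a list of tensors;
# this PR adds a 0-dim Tensor divisor for both the out-of-place and
# in-place variants.
tensors = [torch.randn(3, device="cuda") for _ in range(4)]
divisor = torch.tensor(2.0, device="cuda")  # 0-dim Tensor

expected = [t / divisor for t in tensors]

out = torch._foreach_div(tensors, divisor)   # out-of-place overload
torch._foreach_div_(tensors, divisor)        # in-place overload

for o, t, e in zip(out, tensors, expected):
    torch.testing.assert_close(o, e)
    torch.testing.assert_close(t, e)
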
1 change: 1 addition & 0 deletions aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -334,6 +334,7 @@ FOREACH_BINARY_OP_LIST_ALPHA(lerp);

FOREACH_BINARY_OP_TENSOR_ALPHA(add);
FOREACH_BINARY_OP_TENSOR(mul);
FOREACH_BINARY_OP_TENSOR(div);

FOREACH_BINARY_OP_SCALAR(add);
FOREACH_BINARY_OP_SCALAR(sub);
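The slow path registered here through FOREACH_BINARY_OP_TENSOR(div) is, semantically, just a per-tensor loop over the regular div op; a hedged Python sketch of that reference behavior (not the actual C++ implementation):

from typing import List
import torch

def foreach_div_tensor_reference(tensors: List[torch.Tensor],
                                 other: torch.Tensor) -> List[torch.Tensor]:
    # Reference semantics of the slow path: divide each list entry by the
    # single (0-dim) tensor `other`, one op at a time.
    return [t.div(other) for t in tensors]
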
8 changes: 8 additions & 0 deletions aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
@@ -9,6 +9,7 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_foreach_add_native.h>
#include <ATen/ops/_foreach_div_native.h>
#include <ATen/ops/_foreach_mul_native.h>

#include <ATen/ops/empty_like_native.h>
@@ -187,4 +188,11 @@ FOREACH_BINARY_OP_SCALAR_TENSOR(
mul,
std::multiplies,
/* div_op */ false);

FOREACH_BINARY_OP_SCALAR_TENSOR(
all_types_complex_bool_half_bfloat16,
div,
std::divides,
/* div_op */ true);

} // namespace at::native
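One way to confirm the CUDA fast path registered in this file is actually taken is to look for the fused multi_tensor_apply kernel in a profiler trace, mirroring the assertion the test harness makes further down; a hedged sketch:

import torch
from torch.profiler import profile, ProfilerActivity

tensors = [torch.randn(3, device="cuda") for _ in range(4)]
divisor = torch.tensor(2.0, device="cuda")

with profile(activities=[ProfilerActivity.CUDA]) as p:
    torch._foreach_div(tensors, divisor)

keys = [evt.key for evt in p.key_averages()]
# True when the fused multi_tensor_apply kernel (fast path) ran.
print(any("multi_tensor_apply_kernel" in k for k in keys))
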
15 changes: 15 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -10343,6 +10343,21 @@
CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
autogen: _foreach_div.ScalarList_out

- func: _foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[]
  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
variants: function
dispatch:
CPU: foreach_tensor_div_tensor_kernel_slow
CUDA: foreach_tensor_div_tensor_kernel_cuda

- func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
variants: function
dispatch:
CPU: foreach_tensor_div_tensor_kernel_slow_
CUDA: foreach_tensor_div_tensor_kernel_cuda_
autogen: _foreach_div.Tensor_out

- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
variants: function
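The two new schemas are reachable as explicit overloads of the aten op packet, with the dispatch entries above selecting the CPU slow kernel or the CUDA kernel; a small sketch on CPU:

import torch

tensors = [torch.ones(2), torch.ones(3)]
other = torch.tensor(4.0)  # 0-dim Tensor

# Dispatches to foreach_tensor_div_tensor_kernel_slow per the yaml entry.
out = torch.ops.aten._foreach_div.Tensor(tensors, other)
print([o.tolist() for o in out])  # [[0.25, 0.25], [0.25, 0.25, 0.25]]

# In-place variant, foreach_tensor_div_tensor_kernel_slow_ on CPU.
torch.ops.aten._foreach_div_.Tensor(tensors, other)
print([t.tolist() for t in tensors])  # same values, written in place
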
3 changes: 3 additions & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -200,9 +200,12 @@ aten::_foreach_div.Scalar
aten::_foreach_div.ScalarList
aten::_foreach_div.ScalarList_out
aten::_foreach_div.Scalar_out
aten::_foreach_div.Tensor
aten::_foreach_div.Tensor_out
aten::_foreach_div_.List
aten::_foreach_div_.Scalar
aten::_foreach_div_.ScalarList
aten::_foreach_div_.Tensor
aten::_foreach_erf
aten::_foreach_erf.out
aten::_foreach_erf_
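This expect file is the allowlist of ops without a registered decomposition; the same fact can be spot-checked at runtime. A hedged sketch, which presumably prints False for both new overloads:

import torch
from torch._decomp import decomposition_table

# Neither the functional overload nor the autogen'd out= overload has a
# decomposition, matching the new expect-file entries.
print(torch.ops.aten._foreach_div.Tensor in decomposition_table)
print(torch.ops.aten._foreach_div.Tensor_out in decomposition_table)
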
8 changes: 5 additions & 3 deletions test/test_foreach.py
@@ -48,7 +48,7 @@ def __init__(self, func):
# Some foreach functions don't have in-place implementations.
self.is_inplace = False if func is None else func.__name__.endswith('_')

def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
Contributor Author commented:
rename to expect_fastpath, which is what it really is and is less confusing

actual = None
zero_size = kwargs.pop("zero_size", False)
if (
@@ -60,7 +60,7 @@ def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
actual = self.func(*inputs, **kwargs)
keys = tuple([e.key for e in p.key_averages()])
mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
assert mta_called == (is_fastpath and (not zero_size))
assert mta_called == (expect_fastpath and (not zero_size))
else:
actual = self.func(*inputs, **kwargs)
# note(mkozuki): inplace foreach functions are void functions.
@@ -150,7 +150,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace):
for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous):
ref_kwargs = sample.kwargs
kwargs = ref_kwargs.copy()
expect_fastpath = not (noncontiguous or sample.disable_fastpath)
# div promotes ints to floats, so we cannot go on the fastpath there
div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div'
expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath)
if op in foreach_pointwise_op_db:
values = kwargs.pop("values", None)
if values is not None:
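The div_slowpath flag above exists because true division promotes integral (and bool) inputs to a floating dtype, so the fused same-dtype fast path cannot be used for those inputs; a hedged sketch of that promotion:

import torch

ints = [torch.arange(1, 4), torch.arange(4, 7)]   # int64 inputs
divisor = torch.tensor(2)                          # 0-dim int64 Tensor

out = torch._foreach_div(ints, divisor)
# int / int promotes to a float output, so the output dtype differs from
# the input dtype and the test expects the slow path for these dtypes.
print(out[0].dtype)  # torch.float32
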
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -9479,7 +9479,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
"div",
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
sample_inputs_func=foreach_inputs_sample_func(2, True, True),
sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
skips=(
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
1 change: 1 addition & 0 deletions torchgen/api/autograd.py
@@ -319,6 +319,7 @@ def is_foreach_func(f: NativeFunction) -> bool:
_foreach_with_tensor_overload = {
"_foreach_add.Tensor",
"_foreach_mul.Tensor",
"_foreach_div.Tensor",
}


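With "_foreach_div.Tensor" added to this allowlist, the autograd codegen derives the backward for the new overload from the regular div derivative; a hedged sketch of gradients flowing through it:

import torch

tensors = [torch.randn(3, requires_grad=True) for _ in range(2)]
divisor = torch.tensor(2.0)  # 0-dim Tensor divisor

out = torch._foreach_div(tensors, divisor)
torch.stack(out).sum().backward()

# d(t / d) / dt == 1 / d for every element.
for t in tensors:
    torch.testing.assert_close(t.grad, torch.full_like(t, 1.0 / divisor.item()))
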