Add 0dim Tensor overload for _foreach_div #113688

Closed · wants to merge 5 commits
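As a quick illustration of what the new overload enables (a minimal sketch, assuming a CUDA build; the sample values are made up for this example):

import torch

# _foreach_div now accepts a 0-dim Tensor divisor, mirroring the existing
# _foreach_add/_foreach_mul Tensor overloads.
tensors = [torch.randn(3, device="cuda") for _ in range(4)]
divisor = torch.tensor(2.0, device="cuda")  # 0-dim tensor

out = torch._foreach_div(tensors, divisor)   # out-of-place, returns a new list
torch._foreach_div_(tensors, divisor)        # in-place variant

# After the in-place call, each input tensor matches the out-of-place result.
assert all(torch.allclose(o, t) for o, t in zip(out, tensors))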
1 change: 1 addition & 0 deletions aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -334,6 +334,7 @@ FOREACH_BINARY_OP_LIST_ALPHA(lerp);

FOREACH_BINARY_OP_TENSOR_ALPHA(add);
FOREACH_BINARY_OP_TENSOR(mul);
FOREACH_BINARY_OP_TENSOR(div);

FOREACH_BINARY_OP_SCALAR(add);
FOREACH_BINARY_OP_SCALAR(sub);
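FOREACH_BINARY_OP_TENSOR(div) registers the slow-path (reference) kernels, which behave like a straightforward per-tensor loop. A rough Python-level sketch of those semantics (illustrative only, not the actual C++ implementation):

import torch

def foreach_div_reference(tensors, other):
    # Slow-path semantics: apply div tensor-by-tensor; `other` is a 0-dim tensor.
    return [t.div(other) for t in tensors]

ts = [torch.randn(3) for _ in range(4)]
other = torch.tensor(2.0)
expected = foreach_div_reference(ts, other)
actual = torch._foreach_div(ts, other)
assert all(torch.allclose(a, b) for a, b in zip(actual, expected))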
16 changes: 16 additions & 0 deletions aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
@@ -9,6 +9,7 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_foreach_add_native.h>
#include <ATen/ops/_foreach_div_native.h>
#include <ATen/ops/_foreach_mul_native.h>

#include <ATen/ops/empty_like_native.h>
@@ -98,6 +99,10 @@ void foreach_binary_op_(
#define FOREACH_BINARY_OP_SCALAR_TENSOR(FUNCTION, NAME, OP, DIVISION_OP) \
void foreach_tensor_##NAME##_tensor_kernel_cuda_( \
TensorList tensors, const Tensor& scalar) { \
if (scalar.device().type() == DeviceType::CPU) { \
return at::native::foreach_tensor_##NAME##_scalar_kernel_cuda_( \
tensors, scalar.item()); \
} \
check_foreach_api_restrictions(tensors); \
if (!(can_use_fast_route( \
ArrayRef<TensorList>{tensors}, {}, DIVISION_OP) && \
@@ -111,6 +116,10 @@ void foreach_binary_op_(
\
std::vector<Tensor> foreach_tensor_##NAME##_tensor_kernel_cuda( \
TensorList tensors, const Tensor& scalar) { \
if (scalar.device().type() == DeviceType::CPU) { \
return at::native::foreach_tensor_##NAME##_scalar_kernel_cuda( \
tensors, scalar.item()); \
} \
check_foreach_api_restrictions(tensors); \
if (!(can_use_fast_route( \
ArrayRef<TensorList>{tensors}, {}, DIVISION_OP) && \
@@ -187,4 +196,11 @@ FOREACH_BINARY_OP_SCALAR_TENSOR(
mul,
std::multiplies,
/* div_op */ false);

FOREACH_BINARY_OP_SCALAR_TENSOR(
all_types_complex_bool_half_bfloat16,
div,
std::divides,
/* div_op */ true);

} // namespace at::native
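The DeviceType::CPU branches above route a 0-dim CPU divisor to the existing Scalar kernels via .item() instead of the tensor fast path. A small sketch of the user-visible consequence (assuming a CUDA device is available):

import torch

cuda_tensors = [torch.ones(2, 2, device="cuda") for _ in range(3)]
cpu_scalar = torch.tensor(4.0)  # 0-dim tensor on CPU

# With the fallback, these two calls should produce the same result: the CUDA
# kernel unwraps the CPU tensor with .item() and dispatches to the Scalar
# overload rather than requiring the divisor to live on the GPU.
a = torch._foreach_div(cuda_tensors, cpu_scalar)
b = torch._foreach_div(cuda_tensors, cpu_scalar.item())
assert all(torch.allclose(x, y) for x, y in zip(a, b))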
15 changes: 15 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -10343,6 +10343,21 @@
CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
autogen: _foreach_div.ScalarList_out

- func: _foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
CPU: foreach_tensor_div_tensor_kernel_slow
CUDA: foreach_tensor_div_tensor_kernel_cuda

- func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
CPU: foreach_tensor_div_tensor_kernel_slow_
CUDA: foreach_tensor_div_tensor_kernel_cuda_
autogen: _foreach_div.Tensor_out

- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
3 changes: 3 additions & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -200,9 +200,12 @@ aten::_foreach_div.Scalar
aten::_foreach_div.ScalarList
aten::_foreach_div.ScalarList_out
aten::_foreach_div.Scalar_out
aten::_foreach_div.Tensor
aten::_foreach_div.Tensor_out
aten::_foreach_div_.List
aten::_foreach_div_.Scalar
aten::_foreach_div_.ScalarList
aten::_foreach_div_.Tensor
aten::_foreach_erf
aten::_foreach_erf.out
aten::_foreach_erf_
34 changes: 24 additions & 10 deletions test/test_foreach.py
@@ -48,7 +48,7 @@ def __init__(self, func):
# Some foreach functions don't have in-place implementations.
self.is_inplace = False if func is None else func.__name__.endswith('_')

def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
[Contributor Author comment on this line: rename to expect_fastpath, which is what it really is and is less confusing]
actual = None
zero_size = kwargs.pop("zero_size", False)
if (
Expand All @@ -60,7 +60,7 @@ def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
actual = self.func(*inputs, **kwargs)
keys = tuple([e.key for e in p.key_averages()])
mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
assert mta_called == (is_fastpath and (not zero_size))
assert mta_called == (expect_fastpath and (not zero_size))
else:
actual = self.func(*inputs, **kwargs)
# note(mkozuki): inplace foreach functions are void functions.
@@ -127,9 +127,9 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):

for sample in op.sample_zero_size_inputs(device, dtype):
if not op.has_no_out_of_place:
wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, is_fastpath=True, zero_size=True)
wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)
with InplaceForeachVersionBumpCheck(self, sample.input):
inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, is_fastpath=True, zero_size=True)
inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)

@unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
@ops(
@@ -150,7 +150,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace):
for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous):
ref_kwargs = sample.kwargs
kwargs = ref_kwargs.copy()
expect_fastpath = not (noncontiguous or sample.disable_fastpath)
# div promotes ints to floats, so we cannot go on the fastpath there
div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div'
expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath)
if op in foreach_pointwise_op_db:
values = kwargs.pop("values", None)
if values is not None:
@@ -242,7 +244,7 @@ def clone(arg):
(rhs_arg,) = transformed_sample.args
ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
sum(wrapped_op(
[rhs_arg, tensors], is_cuda=False, is_fastpath=False
[rhs_arg, tensors], is_cuda=False, expect_fastpath=False
)).mean().backward()
sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward()
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
@@ -628,7 +630,7 @@ def test_foreach_l2_large_value_input(self, device, dtype, op):
# make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
self.assertTrue(scaler * scaler * N > max_value)
fn, ref_fn, *_ = self._get_funcs(op)
actual = fn(inputs, is_cuda=True, is_fastpath=True, ord=ord, zero_size=False)
actual = fn(inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
expect = ref_fn(inputs, ord=ord)

if dtype == torch.float16:
@@ -692,7 +694,7 @@ def test_outplace_with_invalid_grads(self, device, dtype, op):
self.assertTrue(all(t.requires_grad for t in sample.input))
if func.func in foreach_pointwise_op_db:
sample.kwargs.pop("values", None)
(out1, out2) = func([sample.input, *sample.args], is_cuda=False, is_fastpath=False, **sample.kwargs)
(out1, out2) = func([sample.input, *sample.args], is_cuda=False, expect_fastpath=False, **sample.kwargs)
out1.backward(torch.ones_like(out1))
self.assertIsNotNone(sample.input[0].grad)
self.assertIsNone(sample.input[1].grad)
@@ -710,7 +712,7 @@ def get_ref(func, sample):
class Foo:
pass

out = func((sample.input, *sample.args), is_cuda=False, is_fastpath=False, **sample.kwargs)
out = func((sample.input, *sample.args), is_cuda=False, expect_fastpath=False, **sample.kwargs)
foo = Foo()
meta_dict = out[0].grad_fn.metadata
meta_dict[0] = foo
@@ -772,12 +774,24 @@ def test_tensors_grouping(self):
self.assertEqual(l3[i], list3[index])
self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)

@onlyCUDA
def test_0dim_tensor_overload_cpu_ok(self):
tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
scalar_cpu_tensor = torch.tensor(4.0, device="cpu")

# For mul and div, the scalar is allowed to be on CPU too
actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
actual = torch._foreach_div(tensors, scalar_cpu_tensor)
self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])


@onlyCUDA
def test_0dim_tensor_overload_exception(self):
# check exceptions of fast path
tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)]
with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
torch._foreach_mul(tensors, torch.tensor(1.0, device="cpu"))
torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)
[Contributor Author comment on this line: This works now, but made me realize I didn't add a case for _foreach_add when I added the overload. Adding that now.]


tensors = [make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")]
with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
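The expect_fastpath change in test_parity above reflects that true division promotes integer and bool inputs to floating point, so the multi_tensor_apply fast path is not expected for those dtypes. A short illustration of the promotion (a sketch; the exact output dtype depends on torch.get_default_dtype()):

import torch

int_tensors = [torch.arange(4, dtype=torch.int64) for _ in range(2)]
out = torch._foreach_div(int_tensors, 2)

# True division promotes integral inputs to a floating dtype, unlike add/mul,
# which is why the test no longer expects the fast path for int/bool inputs.
print([t.dtype for t in out])  # e.g. [torch.float32, torch.float32]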
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -9479,7 +9479,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
"div",
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
sample_inputs_func=foreach_inputs_sample_func(2, True, True),
sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
skips=(
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
1 change: 1 addition & 0 deletions torchgen/api/autograd.py
@@ -319,6 +319,7 @@ def is_foreach_func(f: NativeFunction) -> bool:
_foreach_with_tensor_overload = {
"_foreach_add.Tensor",
"_foreach_mul.Tensor",
"_foreach_div.Tensor",
}

