From 50dfbeeed71380b67e6cba6a5477177733dc15bd Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 14 Nov 2023 13:16:02 -0800 Subject: [PATCH 1/5] Add 0dim Tensor overload for _foreach_div [ghstack-poisoned] --- aten/src/ATen/native/ForeachOpsKernels.cpp | 1 + .../native/cuda/ForeachBinaryOpScalarTensor.cu | 8 ++++++++ aten/src/ATen/native/native_functions.yaml | 15 +++++++++++++++ .../HasDecompTest.test_has_decomposition.expect | 3 +++ test/test_foreach.py | 2 ++ .../_internal/common_methods_invocations.py | 2 +- torchgen/api/autograd.py | 1 + 7 files changed, 31 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index b649673681193..a66209cbe807b 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -334,6 +334,7 @@ FOREACH_BINARY_OP_LIST_ALPHA(lerp); FOREACH_BINARY_OP_TENSOR_ALPHA(add); FOREACH_BINARY_OP_TENSOR(mul); +FOREACH_BINARY_OP_TENSOR(div); FOREACH_BINARY_OP_SCALAR(add); FOREACH_BINARY_OP_SCALAR(sub); diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu index c9783ae551cd4..e9cc54d7c5c9a 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu @@ -9,6 +9,7 @@ #include #else #include +#include #include #include @@ -187,4 +188,11 @@ FOREACH_BINARY_OP_SCALAR_TENSOR( mul, std::multiplies, /* div_op */ false); + +FOREACH_BINARY_OP_SCALAR_TENSOR( + all_types_complex_bool_half_bfloat16, + div, + std::divides, + /* div_op */ true); + } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 504c6cb772571..2a97ecd0eb6d6 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10343,6 +10343,21 @@ CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ autogen: _foreach_div.ScalarList_out +- func: _foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CPU: foreach_tensor_div_tensor_kernel_slow + CUDA: foreach_tensor_div_tensor_kernel_cuda + +- func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CPU: foreach_tensor_div_tensor_kernel_slow_ + CUDA: foreach_tensor_div_tensor_kernel_cuda_ + autogen: _foreach_div.Tensor_out + - func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 74e3cb0212135..b43759643a09e 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -200,9 +200,12 @@ aten::_foreach_div.Scalar aten::_foreach_div.ScalarList aten::_foreach_div.ScalarList_out aten::_foreach_div.Scalar_out +aten::_foreach_div.Tensor +aten::_foreach_div.Tensor_out aten::_foreach_div_.List aten::_foreach_div_.Scalar aten::_foreach_div_.ScalarList +aten::_foreach_div_.Tensor aten::_foreach_erf aten::_foreach_erf.out aten::_foreach_erf_ diff --git a/test/test_foreach.py b/test/test_foreach.py index f44627dabf4d6..3ea37d3f57b4e 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -60,6 +60,8 @@ def __call__(self, inputs, is_cuda, is_fastpath, **kwargs): actual = self.func(*inputs, **kwargs) keys = tuple([e.key for e in p.key_averages()]) mta_called = any("multi_tensor_apply_kernel" in k for k in keys) + print(any("multi_tensor_apply_kernel" in k for k in keys)) + print((is_fastpath and (not zero_size))) assert mta_called == (is_fastpath and (not zero_size)) else: actual = self.func(*inputs, **kwargs) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 028596fce5bf6..d26b952f271cb 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -9479,7 +9479,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "div", dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), - sample_inputs_func=foreach_inputs_sample_func(2, True, True), + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), skips=( DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 39fcf81855f35..1a55211b99902 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -319,6 +319,7 @@ def is_foreach_func(f: NativeFunction) -> bool: _foreach_with_tensor_overload = { "_foreach_add.Tensor", "_foreach_mul.Tensor", + "_foreach_div.Tensor", } From 2e6bb1aa67bd1535df357ee0d527abd3492bebcc Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 14 Nov 2023 14:05:48 -0800 Subject: [PATCH 2/5] Update on "Add 0dim Tensor overload for _foreach_div" [ghstack-poisoned] --- test/test_foreach.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 3ea37d3f57b4e..b4e6c2a0d52b8 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -48,7 +48,7 @@ def __init__(self, func): # Some foreach functions don't have in-place implementations. self.is_inplace = False if func is None else func.__name__.endswith('_') - def __call__(self, inputs, is_cuda, is_fastpath, **kwargs): + def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): actual = None zero_size = kwargs.pop("zero_size", False) if ( @@ -60,9 +60,7 @@ def __call__(self, inputs, is_cuda, is_fastpath, **kwargs): actual = self.func(*inputs, **kwargs) keys = tuple([e.key for e in p.key_averages()]) mta_called = any("multi_tensor_apply_kernel" in k for k in keys) - print(any("multi_tensor_apply_kernel" in k for k in keys)) - print((is_fastpath and (not zero_size))) - assert mta_called == (is_fastpath and (not zero_size)) + assert mta_called == (expect_fastpath and (not zero_size)) else: actual = self.func(*inputs, **kwargs) # note(mkozuki): inplace foreach functions are void functions. @@ -152,7 +150,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous): ref_kwargs = sample.kwargs kwargs = ref_kwargs.copy() - expect_fastpath = not (noncontiguous or sample.disable_fastpath) + # div promotes ints to floats, so we cannot go on the fastpath there + div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div' + expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath) if op in foreach_pointwise_op_db: values = kwargs.pop("values", None) if values is not None: From 97fd79fe7aaee98f66581bbe4d105b293ab00c5c Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 15 Nov 2023 07:29:48 -0800 Subject: [PATCH 3/5] Update on "Add 0dim Tensor overload for _foreach_div" cc crcrpar [ghstack-poisoned] --- test/test_foreach.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index b4e6c2a0d52b8..1fb3880b22067 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -127,9 +127,9 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op): for sample in op.sample_zero_size_inputs(device, dtype): if not op.has_no_out_of_place: - wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, is_fastpath=True, zero_size=True) + wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True) with InplaceForeachVersionBumpCheck(self, sample.input): - inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, is_fastpath=True, zero_size=True) + inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True) @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") @ops( @@ -244,7 +244,7 @@ def clone(arg): (rhs_arg,) = transformed_sample.args ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg) sum(wrapped_op( - [rhs_arg, tensors], is_cuda=False, is_fastpath=False + [rhs_arg, tensors], is_cuda=False, expect_fastpath=False )).mean().backward() sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward() self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors]) @@ -630,7 +630,7 @@ def test_foreach_l2_large_value_input(self, device, dtype, op): # make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`. self.assertTrue(scaler * scaler * N > max_value) fn, ref_fn, *_ = self._get_funcs(op) - actual = fn(inputs, is_cuda=True, is_fastpath=True, ord=ord, zero_size=False) + actual = fn(inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False) expect = ref_fn(inputs, ord=ord) if dtype == torch.float16: @@ -694,7 +694,7 @@ def test_outplace_with_invalid_grads(self, device, dtype, op): self.assertTrue(all(t.requires_grad for t in sample.input)) if func.func in foreach_pointwise_op_db: sample.kwargs.pop("values", None) - (out1, out2) = func([sample.input, *sample.args], is_cuda=False, is_fastpath=False, **sample.kwargs) + (out1, out2) = func([sample.input, *sample.args], is_cuda=False, expect_fastpath=False, **sample.kwargs) out1.backward(torch.ones_like(out1)) self.assertIsNotNone(sample.input[0].grad) self.assertIsNone(sample.input[1].grad) @@ -712,7 +712,7 @@ def get_ref(func, sample): class Foo: pass - out = func((sample.input, *sample.args), is_cuda=False, is_fastpath=False, **sample.kwargs) + out = func((sample.input, *sample.args), is_cuda=False, expect_fastpath=False, **sample.kwargs) foo = Foo() meta_dict = out[0].grad_fn.metadata meta_dict[0] = foo From f46ca558e2e182a66083a334b72a16d938bedaeb Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 15 Nov 2023 08:59:19 -0800 Subject: [PATCH 4/5] Update on "Add 0dim Tensor overload for _foreach_div" cc crcrpar [ghstack-poisoned] --- aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu index e9cc54d7c5c9a..ad5eeee5ebec4 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu @@ -99,6 +99,10 @@ void foreach_binary_op_( #define FOREACH_BINARY_OP_SCALAR_TENSOR(FUNCTION, NAME, OP, DIVISION_OP) \ void foreach_tensor_##NAME##_tensor_kernel_cuda_( \ TensorList tensors, const Tensor& scalar) { \ + if (scalar.device().type() == DeviceType::CPU) { \ + return at::native::foreach_tensor_##NAME##_scalar_kernel_cuda_( \ + tensors, scalar.item()); \ + } \ check_foreach_api_restrictions(tensors); \ if (!(can_use_fast_route( \ ArrayRef{tensors}, {}, DIVISION_OP) && \ @@ -112,6 +116,10 @@ void foreach_binary_op_( \ std::vector foreach_tensor_##NAME##_tensor_kernel_cuda( \ TensorList tensors, const Tensor& scalar) { \ + if (scalar.device().type() == DeviceType::CPU) { \ + return at::native::foreach_tensor_##NAME##_scalar_kernel_cuda( \ + tensors, scalar.item()); \ + } \ check_foreach_api_restrictions(tensors); \ if (!(can_use_fast_route( \ ArrayRef{tensors}, {}, DIVISION_OP) && \ From ed393570850abd1689f9c6e1fd8b12630fa7d584 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 15 Nov 2023 10:09:18 -0800 Subject: [PATCH 5/5] Update on "Add 0dim Tensor overload for _foreach_div" This PR is ALMOST basically just following the steps from #106677 EXCEPT! We do add one feature. Similar to fused_adam(w), for the CUDA dispatches: when the scalar tensor is on CPU, we .item and redispatch to the normal scalar overload. Otherwise, the cuda kernel will complain about mismatch in devices between the scalar and the tensors. Why do we add this feature? Our optimizers want to allow lr as a tensor, and lr could be a CPU tensor. lr is used with foreach_div_ in Adam, so our CI will break otherwise. After this PR, `_foreach_mul` and `_foreach_div` will accept either a CPU or a GPU tensor for the scalar tensor (vs only a GPU tensor). They join the ranks of `fused_adam(w)` in this characteristic. I did not yet do the same thing for foreach_add (the only other foreach op with a .Tensor overload) because there is no use case and will be more involved. cc crcrpar [ghstack-poisoned] --- test/test_foreach.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 1fb3880b22067..52d91f681c301 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -774,12 +774,24 @@ def test_tensors_grouping(self): self.assertEqual(l3[i], list3[index]) self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list) + @onlyCUDA + def test_0dim_tensor_overload_cpu_ok(self): + tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)] + scalar_cpu_tensor = torch.tensor(4.0, device="cpu") + + # For mul and div, the scalar is allowed to be on CPU too + actual = torch._foreach_mul(tensors, scalar_cpu_tensor) + self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors]) + actual = torch._foreach_div(tensors, scalar_cpu_tensor) + self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors]) + + @onlyCUDA def test_0dim_tensor_overload_exception(self): # check exceptions of fast path tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)] with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"): - torch._foreach_mul(tensors, torch.tensor(1.0, device="cpu")) + torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0) tensors = [make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")] with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):