From 26b5e27ace804169ee2b2d2bf9933b0cf87bd0cb Mon Sep 17 00:00:00 2001
From: CaoE
Date: Sun, 5 Nov 2023 12:31:38 +0000
Subject: [PATCH] Add Half support for cummax, cummin, cumprod, logcumsumexp,
 and prod on CPU (#112132)

Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112132
Approved by: https://github.com/cpuhrsch
---
 aten/src/ATen/native/ReduceOps.cpp            |  4 ++--
 aten/src/ATen/native/cpu/ReduceOpsKernel.cpp  |  6 +++---
 test/test_mps.py                              |  5 ++++-
 test/test_reductions.py                       | 12 ++++++++++++
 test/test_sparse.py                           | 11 ++++++++++-
 .../_internal/common_methods_invocations.py   | 14 +++++---------
 .../_internal/opinfo/definitions/_masked.py   | 15 ++++++++-------
 7 files changed, 44 insertions(+), 23 deletions(-)

diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 04d0e0c3b51ff..7a47490c6745c 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -794,7 +794,7 @@ void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data
 }
 
 void cummax_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
-  AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16,
+  AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
     self.scalar_type(), "cummax_cpu",
     [&] {
       at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::greater_equal<scalar_t>>);
@@ -829,7 +829,7 @@ std::tuple<Tensor, Tensor> cummax(const Tensor& self, int64_t dim) {
 }
 
 void cummin_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
-  AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16,
+  AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
     self.scalar_type(), "cummin_cpu",
     [&] {
       at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::less_equal<scalar_t>>);
diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index f4f73f405620c..405fda4d58230 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -100,7 +100,7 @@ static void cumprod_cpu_kernel(const Tensor& result, const Tensor& self, int64_t
   auto wrap_dim = maybe_wrap_dim(dim, self.dim());
   int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, self.scalar_type(), "cumprod_out_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, self.scalar_type(), "cumprod_out_cpu", [&] {
     cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
       scalar_t* result_data, auto result_dim_stride,
       const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
@@ -119,7 +119,7 @@ static void logcumsumexp_cpu_kernel(Tensor& result, const Tensor& self, int64_t
   auto wrap_dim = maybe_wrap_dim(dim, self.dim());
   int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
 
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
     cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
       scalar_t* result_data, auto result_dim_stride,
       const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
@@ -176,7 +176,7 @@ static void prod_kernel_impl(TensorIterator& iter) {
         // NOLINTNEXTLINE(bugprone-argument-comment)
         /*identity=*/1);
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "prod_out_cpu", [&] {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "prod_out_cpu", [&] {
"prod_out_cpu", [&] { binary_kernel_reduce_vec( iter, [=](scalar_t a, scalar_t b) diff --git a/test/test_mps.py b/test/test_mps.py index b359a4c839b43..817e11eff493d 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -164,7 +164,9 @@ def mps_ops_grad_modifier(ops): '__rpow__': [torch.float32], # See https://github.com/pytorch/pytorch/issues/106112 for more information - 'cumprod': [torch.float32], + 'cumprod': [torch.float32, torch.float16], + # See https://github.com/pytorch/pytorch/issues/109166 for more information + 'masked.cumprod': [torch.float16], } SKIPLIST_GRAD = { @@ -10943,6 +10945,7 @@ class TestConsistency(TestCaseMPS): 'nn.functional.kl_div', 'nn.functional.softmin', 'cross', 'linalg.cross', + 'prod', 'masked.prod', # for macOS 12 'masked.normalize', 'masked.sum', 'masked.var', diff --git a/test/test_reductions.py b/test/test_reductions.py index f75526766fb0a..272769808d6fe 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -1435,6 +1435,18 @@ def test_prod(self, device, dtype): torch.prod(x, 1, out=res2) self.assertEqual(res1, res2) + @onlyCPU + @dtypes(torch.float16, torch.bfloat16) + def test_prod_lowp(self, device, dtype): + x = torch.rand(100, 100, dtype=dtype, device=device) + x_ref = x.float() + res1 = torch.prod(x, 1) + res2 = torch.prod(x_ref, 1) + self.assertEqual(res1, res2.to(dtype=dtype)) + res1 = torch.prod(x, 0) + res2 = torch.prod(x_ref, 0) + self.assertEqual(res1, res2.to(dtype=dtype)) + def test_prod_bool(self, device): vals = [[True, True], [True, False], [False, False], []] for val in vals: diff --git a/test/test_sparse.py b/test/test_sparse.py index e335fa94f2225..fad4db97e4d2c 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -4225,6 +4225,10 @@ def fn(x): class TestSparseMaskedReductions(TestCase): exact_dtype = True + fp16_low_precision_list = { + 'masked.prod', + } + @ops(sparse_masked_reduction_ops) def test_future_empty_dim(self, device, dtype, op): """Currently, `dim=()` in reductions operations means "reduce over @@ -4263,7 +4267,12 @@ def test_future_empty_dim(self, device, dtype, op): self.assertEqual(actual.layout, torch.sparse_coo) expected = op(t, *sample_input.args, **sample_input_kwargs).to_sparse() - self.assertEqual(actual, expected) + atol = None + rtol = None + if op.name in self.fp16_low_precision_list and dtype == torch.half: + atol = 1e-5 + rtol = 2e-3 + self.assertEqual(actual, expected, atol=atol, rtol=rtol) class TestSparseMeta(TestCase): diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 1667fa0c3ab44..bf5fcdb42f12e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10821,8 +10821,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ), sample_inputs_func=sample_inputs_cumulative_ops), OpInfo('cumprod', - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( @@ -10833,8 +10832,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_cumprod, gradcheck_fast_mode=False), OpInfo('cummax', - dtypes=all_types_and(torch.bool, torch.bfloat16), - dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 
           sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
@@ -10842,8 +10840,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            ),
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
     OpInfo('cummin',
-           dtypes=all_types_and(torch.bool, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(
           ),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
@@ -17294,8 +17291,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
               )
           ),
     OpInfo('logcumsumexp',
-           dtypes=floating_and_complex_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
            backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
            backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16),
            supports_forward_ad=True,
@@ -18371,7 +18367,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            supports_fwgrad_bwgrad=True,
            promotes_int_to_int64=True,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
            sample_inputs_func=sample_inputs_prod,
            ref=prod_numpy,
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index d9b44a3d1840a..98fef72aae250 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -504,11 +504,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
         supports_sparse=True,
         supports_sparse_csr=True,
         promotes_int_to_int64=True,
-        # FIXME: "prod_cpu" not implemented for 'Half'
-        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(
-            torch.bool, torch.float16, torch.bfloat16
-        ),
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
         skips=(
             DecorateInfo(
                 unittest.expectedFailure,
@@ -554,6 +550,12 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
                 "TestReductions",
                 "test_ref_small_input",
             ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}),
+                "TestMasked",
+                "test_mask_layout",
+                device_type="cpu",
+            ),
         ],
         sample_inputs_func=sample_inputs_masked_reduction,
         sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
@@ -585,8 +587,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
         ),
     OpInfo(
         "masked.cumprod",
-        dtypes=all_types_and_complex_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
        method_variant=None,
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
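
Not part of the patch: a minimal smoke test of the float16 CPU paths this
change enables, assuming a PyTorch build that includes it. Before this patch,
each call below failed on CPU with an error of the form
RuntimeError: "cummax_cpu" not implemented for 'Half' (compare the FIXME
removed from _masked.py above). Variable names are illustrative only; the
tolerances mirror the toleranceOverride added above.

    import torch

    x = torch.rand(8, 4, dtype=torch.float16)  # CPU float16 tensor

    values, indices = torch.cummax(x, dim=1)   # returns (values, indices)
    values, indices = torch.cummin(x, dim=1)
    cp = torch.cumprod(x, dim=1)
    lse = torch.logcumsumexp(x, dim=1)
    p = torch.prod(x, dim=1)

    # Half accumulates more rounding error than float32, so compare against
    # a float32 reference with loosened tolerances, as the tests above do.
    torch.testing.assert_close(p, torch.prod(x.float(), dim=1).half(),
                               atol=1e-2, rtol=1.5e-3)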