diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 6167f889aeb76..91bf398561724 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -128,20 +128,6 @@
 namespace at {
 namespace native {
 
-inline ScalarType get_dtype_from_self(
-    const Tensor& self,
-    const optional<ScalarType>& dtype,
-    bool promote_integers) {
-  if (dtype.has_value()) {
-    return dtype.value();
-  }
-  ScalarType src_type = self.scalar_type();
-  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
-    return kLong;
-  }
-  return src_type;
-}
-
 } // namespace native
 
 namespace meta {
@@ -1163,14 +1149,6 @@ std::vector<Tensor> gradient(const Tensor& self, IntArrayRef dim, int64_t edge_o
 
 // ALL REDUCE #################################################################
 
-inline ScalarType get_dtype_from_result(Tensor& result, optional<ScalarType> dtype) {
-  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
-  if (dtype.has_value()) {
-    return dtype.value();
-  } else {
-    return result.scalar_type();
-  }
-}
 
 TORCH_IMPL_FUNC(sum_out)
 (const Tensor& self,
diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h
index 2b46eb683f1c9..8aa94c4b45ee7 100644
--- a/aten/src/ATen/native/ReduceOpsUtils.h
+++ b/aten/src/ATen/native/ReduceOpsUtils.h
@@ -320,6 +320,30 @@ static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_i
   at::native::resize_output(result_indices, sizes);
 }
 
+inline ScalarType get_dtype_from_self(
+    const Tensor& self,
+    const c10::optional<ScalarType>& dtype,
+    bool promote_integers) {
+  if (dtype.has_value()) {
+    return dtype.value();
+  }
+  ScalarType src_type = self.scalar_type();
+  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
+    return kLong;
+  }
+  return src_type;
+}
+
+inline ScalarType get_dtype_from_result(Tensor& result, c10::optional<ScalarType> dtype) {
+  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
+  if (dtype.has_value()) {
+    return dtype.value();
+  } else {
+    return result.scalar_type();
+  }
+}
+
+
 } // native
 
 namespace meta {
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 816bf5bcacbb3..88df3af523e8b 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -31,7 +31,8 @@
   PROD,
   MEAN,
   COUNT_NONZERO,
-  TRACE
+  TRACE,
+  NANSUM,
 };
 
 using namespace mps;
@@ -247,6 +248,22 @@ void reduction_out_mps(
       castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor
                                                      axes:@[@0, @1]
                                                      name:nil];
+    } else if (reduction_type == MPSReductionType::NANSUM) {
+      // Create a 0 tensor of the same shape as inputTensor
+      MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0
+                                                  dataType:castInputTensor.dataType];
+      // Find NaNs
+      MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor
+                                                     name:nil];
+      // Replace NaNs with 0
+      MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
+                                                    truePredicateTensor:zeros
+                                                   falsePredicateTensor:castInputTensor
+                                                                   name:nil];
+      // Sum
+      castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced
+                                                     axes:wrappedAxes
+                                                     name:nil];
     }
 
     MPSGraphTensor* outputTensor = nil;
@@ -289,6 +306,33 @@ void reduction_out_mps(
   reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
 }
 
+Tensor& nansum_out_mps(
+    const Tensor& self,
+    OptionalIntArrayRef dim,
+    bool keepdim,
+    c10::optional<ScalarType> opt_dtype,
+    Tensor& result) {
+  TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "nansum does not support complex inputs");
+  if (c10::isIntegralType(self.scalar_type(), true)){
+    return at::sum_out(result, self, dim, keepdim, opt_dtype);
+  }
+  ScalarType dtype = get_dtype_from_result(result, opt_dtype);
+  const auto mask = make_dim_mask(dim, self.dim());
+  resize_reduction_result(result, self, mask, keepdim, dtype);
+  reduction_out_mps(self, dim, keepdim, dtype, result, MPSReductionType::NANSUM, "nansum_out_mps");
+  return result;
+}
+
+Tensor nansum_mps(
+    const Tensor& self,
+    OptionalIntArrayRef dim,
+    bool keepdim,
+    c10::optional<ScalarType> opt_dtype) {
+  ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
+  Tensor result = create_reduction_result(self, dim, keepdim, dtype);
+  return nansum_out_mps(self, dim, keepdim, dtype, result);
+}
+
 Tensor trace_mps_out(const Tensor& self) {
   Tensor output_t = at::native::empty_mps(
       {},
@@ -316,22 +360,6 @@ Tensor trace_mps_out(const Tensor& self) {
   reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps");
 }
 
-// Taken from ReduceOps.cpp
-inline ScalarType get_dtype_from_self(
-    const Tensor& self,
-    const c10::optional<ScalarType>& dtype,
-    bool promote_integers) {
-  if (dtype.has_value()) {
-    return dtype.value();
-  }
-
-  ScalarType src_type = self.scalar_type();
-  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
-    return kLong;
-  }
-  return src_type;
-}
-
 TORCH_IMPL_FUNC(amax_out_mps)(
     const Tensor& input_t,
     IntArrayRef dim,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index bec46a06eec87..86dfe7a5a78bf 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5353,10 +5353,12 @@
   variants: function, method
   dispatch:
     CPU, CUDA: nansum
+    MPS: nansum_mps
 
 - func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, CUDA: nansum_out
+    MPS: nansum_out_mps
 
 - func: sum_to_size(Tensor self, int[] size) -> Tensor
   variants: method
diff --git a/test/test_mps.py b/test/test_mps.py
index 3b3bf9ee7be13..12401f0e5b109 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2279,6 +2279,51 @@ def test_binops_dtype_precedence(self):
                 getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                        (torch.full(full_shape, val2, dtype=dtype2, device='cpu')))
 
+    def test_nansum(self):
+        def helper(dtype, noncontiguous, dim):
+            zero_cpu = torch.zeros((), dtype=dtype)
+
+            # Randomly scale the values
+            scale = random.randint(10, 100)
+            x_cpu: torch.Tensor = make_tensor(
+                (5, 5), dtype=dtype, device='cpu',
+                low=-scale, high=scale, noncontiguous=noncontiguous)
+
+            if dtype.is_floating_point:
+                nan_mask_cpu = x_cpu < (0.2 * scale)
+                x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
+                x_cpu[nan_mask_cpu] = np.nan
+            else:
+                x_no_nan_cpu = x_cpu
+
+            x_mps = x_cpu.to('mps')
+            actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
+            expect_out_cpu = torch.empty(0, dtype=dtype)
+            dim_kwargs = {"dim": dim} if dim is not None else {}
+            expect = torch.sum(x_no_nan_cpu, **dim_kwargs)
+
+            actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
+            # Sanity check on CPU
+            self.assertEqual(expect, actual_cpu)
+
+            # Test MPS
+            actual_mps = torch.nansum(x_mps, **dim_kwargs)
+            # Test out= variant
+            torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
+            torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
+            self.assertEqual(expect, actual_mps)
+            self.assertEqual(expect_out_cpu, actual_out_mps)
+
+        args = itertools.product(
+            (torch.float16, torch.float32, torch.int32, torch.int64),  # dtype
+            (True, False),  # noncontiguous
+            (0, 1, None),  # dim
+        )
+
+        for dtype, noncontiguous, dim in args:
+            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
+                helper(dtype, noncontiguous, dim)
+
 
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
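
For reviewers, a minimal sketch of the behavior this change enables; it assumes a build where `torch.backends.mps.is_available()` returns `True`, and the tensor values are arbitrary:

```python
import torch

# nansum treats NaN entries as zero; with this PR the reduction is handled
# natively by the MPS backend (dispatch keys nansum_mps / nansum_out_mps).
x = torch.tensor([[1.0, float("nan")], [float("nan"), 2.0]], device="mps")

print(torch.nansum(x))         # full reduction, expected 3.0
print(torch.nansum(x, dim=1))  # per-row reduction, expected [1.0, 2.0]

# out= variant, dispatched to nansum_out_mps
out = torch.empty(2, device="mps")
torch.nansum(x, dim=1, out=out)
```

Note that integral inputs take the early `at::sum_out` path in `nansum_out_mps`, since integer tensors cannot contain NaN; the NaN-masking graph is only built for floating-point dtypes.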