Add cumprod support for device mps (#104688)
Related to #77764

Add support for the cumprod operation (which in turn enables its gradient). This also allows us to compute the gradient of prod, whose backward pass was blocked behind cumprod in the case where exactly one element of the tensor is 0.
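A minimal sketch of that prod case (illustrative only; assumes an MPS-capable machine, and the tensor values are not from the PR):

import torch

# The gradient of prod() w.r.t. x[i] is the product of all *other* elements.
# When exactly one element is 0, the backward pass cannot use prod(x) / x[i]
# (division by zero) and instead builds that product from cumprod -- which is
# why prod's gradient was blocked on this kernel.
x = torch.tensor([2.0, 0.0, 3.0], device="mps", requires_grad=True)
x.prod().backward()
print(x.grad)  # expected: tensor([0., 6., 0.], device='mps:0')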

Pull Request resolved: #104688
Approved by: https://github.com/kulinseth
Mr4k authored and pytorchmergebot committed Aug 1, 2023
1 parent fadd085 commit 97e5055
Showing 3 changed files with 102 additions and 25 deletions.
71 changes: 51 additions & 20 deletions aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -18,6 +18,7 @@
#include <ATen/ops/ceil_native.h>
#include <ATen/ops/cos_native.h>
#include <ATen/ops/cosh_native.h>
#include <ATen/ops/cumprod_native.h>
#include <ATen/ops/cumsum_native.h>
#include <ATen/ops/erf_native.h>
#include <ATen/ops/exp2_native.h>
@@ -46,6 +47,12 @@
#endif

namespace at::native {

enum class MPSCumulativeOpType : uint8_t {
CUMSUM = 0,
CUMPROD = 1,
};

namespace mps {

typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*);
@@ -375,8 +382,12 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
}
}

TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
void cumulative_op_impl(const Tensor& self,
int64_t dim,
c10::optional<ScalarType> dtype,
const Tensor& result,
MPSCumulativeOpType cumulativeOpType,
const std::string& op_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
@@ -389,34 +400,54 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
dim,
")");
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade");
auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
TORCH_WARN_ONCE(op_name, " supported by MPS on MacOS 13+, please upgrade");
Tensor cpu_result;
if (cumulativeOpType == MPSCumulativeOpType::CUMSUM) {
cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
} else if (cumulativeOpType == MPSCumulativeOpType::CUMPROD) {
cpu_result = self.to(at::Device(kCPU)).cumprod(dim, dtype);
}
at::_copy_from_and_resize(cpu_result, result);
return;
}
auto input = dtype.has_value() ? self.to(dtype.value()) : self;

// issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to
// issue #103810551: cumsum / cumprod are broken for int8, int16 and as chances for overflow are pretty high, cast to
// int32 fixed in macOS 13.3
bool castInputData = (isIntegralType(input.scalar_type(), false) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);

TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3");

mps::unary_op(input,
result,
"cumsum_out_mp" + std::to_string(dim),
^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
auto rc = [mpsGraph cumulativeSumWithTensor:inputTensor axis:dim name:nil];
if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
return rc;
});
"MPS does not support ",
op_name,
" op with int64 input. Support has been added in macOS 13.3");

mps::unary_op(
input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
MPSGraphTensor* rc;
if (cumulativeOpType == MPSCumulativeOpType::CUMSUM) {
rc = [mpsGraph cumulativeSumWithTensor:inputTensor axis:dim name:nil];
} else if (cumulativeOpType == MPSCumulativeOpType::CUMPROD) {
rc = [mpsGraph cumulativeProductWithTensor:inputTensor axis:dim name:nil];
}
if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
return rc;
});
}

TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
return cumulative_op_impl(self, dim, dtype, result, MPSCumulativeOpType::CUMSUM, "cumsum_out_mps");
}

TORCH_IMPL_FUNC(cumprod_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
return cumulative_op_impl(self, dim, dtype, result, MPSCumulativeOpType::CUMPROD, "cumprod_out_mps");
}

} // namespace at::native
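
For reference, a usage sketch of what the shared cumulative_op_impl above enables (assuming macOS 13+ with an MPS device; on older systems the code warns once and falls back to the CPU kernel, per the branch above):

import torch

t = torch.arange(1, 5, dtype=torch.float32, device="mps")
print(t.cumsum(0))   # tensor([ 1.,  3.,  6., 10.], device='mps:0')
print(t.cumprod(0))  # tensor([ 1.,  2.,  6., 24.], device='mps:0')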
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1899,6 +1899,7 @@
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: cumprod_out
MPS: cumprod_out_mps

- func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
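This one-line dispatch entry is what routes aten::cumprod.out to the new MPS kernel; without it, calling cumprod on an MPS tensor raised the unimplemented-operator error recorded in the test skip list below. A quick smoke test (hypothetical session, assuming an MPS build):

import torch

x = torch.full((4,), 2.0, device="mps")
out = torch.empty(4, device="mps")
torch.cumprod(x, dim=0, out=out)  # now dispatches to cumprod_out_mps
print(out)  # expected: tensor([ 2.,  4.,  8., 16.], device='mps:0')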
55 changes: 50 additions & 5 deletions test/test_mps.py
@@ -65,7 +65,6 @@ def mps_ops_grad_modifier(ops):

# Unimplemented ops
'__getitem__': [torch.float16],
'prod': [torch.float32], # The operator 'aten::cumprod.out'
'sgn': [torch.float16, torch.float32],
'_segment_reduce': [torch.float16, torch.float32],
'unfold_copy': [torch.float16, torch.float32], # unfold_backward is not implemented
@@ -157,6 +156,9 @@ def mps_ops_grad_modifier(ops):
# fixed in macOS 13. We are not raising error.
'pow': [torch.float32],
'__rpow__': [torch.float32],

# See https://github.com/pytorch/pytorch/issues/106112 for more information
'cumprod': [torch.float32],
}

XPASSLIST_GRAD = {
@@ -336,8 +338,11 @@ def mps_ops_modifier(ops):
'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
# Unsupported dtypes
'cumsum': [torch.int64],
'cumprod': [torch.int64],
'cumulative_trapezoid': [torch.int64],
'masked.cumsum': [torch.int64],
'masked.cumprod': [torch.int64],
'linalg.vander': [torch.int64],
}

MACOS_AFTER_13_1_XFAILLIST = {
@@ -404,7 +409,6 @@ def mps_ops_modifier(ops):
'cholesky_solve': None,
'cummax': None,
'cummin': None,
'cumprod': None,
'digamma': None,
'erfc': None,
'frexp': None,
@@ -449,14 +453,12 @@ def mps_ops_modifier(ops):
'linalg.solve_ex': None,
'linalg.svdvals': None,
'linalg.tensorsolve': None,
'linalg.vander': None,
'linalg.vecdot': None,
'logcumsumexp': None,
'logdet': None,
'lu': None,
'lu_solve': None,
'lu_unpack': None,
'masked.cumprod': None,
'masked.median': None,
'matrix_exp': None,
'mode': None,
@@ -3666,7 +3668,8 @@ def helper(dtype):
helper(torch.int64)
except Exception as e:
e_string = str(e)
self.assertEqual(e_string, "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3")
self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." +
" Support has been added in macOS 13.3")

def test_cumsum_minus_one_axis(self):
def helper(dtype):
@@ -3685,6 +3688,41 @@ def helper(dtype):

[helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

def test_cumprod_all_dtypes(self):
def helper(dtype):
t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

a = t.cumprod(0, dtype=dtype)
a_cpu = t_cpu.cumprod(0, dtype=dtype)

self.assertEqual(a.cpu(), a_cpu)
[helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

try:
helper(torch.int64)
except Exception as e:
e_string = str(e)
self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input."
+ " Support has been added in macOS 13.3")

def test_cumprod_minus_one_axis(self):
def helper(dtype):
# Test with axis -1
cpu_x = None
if(dtype == torch.float32):
cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
else:
cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
x = cpu_x.detach().clone().to('mps')

cpu_y = cpu_x.cumprod(-1)
y = x.cumprod(-1)

self.assertEqual(y, cpu_y)

[helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

def test_median_int16(self):
def helper(shape, dtype):
cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
@@ -7626,6 +7664,13 @@ def test_cumsum_dim_check(self):
self.assertRaises(IndexError, lambda: x.cumsum(2))
self.assertRaises(IndexError, lambda: x.cumsum(-3))

def test_cumprod_dim_check(self):
x = torch.rand((3, 3), device="mps")
self.assertEqual(x.cumprod(1), x.cumprod(-1))
self.assertEqual(x.cumprod(0), x.cumprod(-2))
self.assertRaises(IndexError, lambda: x.cumprod(2))
self.assertRaises(IndexError, lambda: x.cumprod(-3))


class TestTopK(TestCase):
def _test_topk(self, shape, largest):
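One detail worth noting in the tests above: the int64 error string now embeds the op_name passed to cumulative_op_impl, so the assertions match "cumsum_out_mps" / "cumprod_out_mps" rather than the plain op names. A hypothetical reproduction on a macOS 13 build older than 13.3 (earlier macOS versions fall back to the CPU kernel instead of raising):

import torch

try:
    torch.ones(3, dtype=torch.int64, device="mps").cumprod(0)
except RuntimeError as e:
    # Expected per the TORCH_CHECK above:
    # "MPS does not support cumprod_out_mps op with int64 input. ..."
    print(e)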
