Add cumprod support for device mps #104688

Closed · wants to merge 23 commits
71 changes: 51 additions & 20 deletions aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -18,6 +18,7 @@
#include <ATen/ops/ceil_native.h>
#include <ATen/ops/cos_native.h>
#include <ATen/ops/cosh_native.h>
#include <ATen/ops/cumprod_native.h>
#include <ATen/ops/cumsum_native.h>
#include <ATen/ops/erf_native.h>
#include <ATen/ops/exp2_native.h>
@@ -46,6 +47,12 @@
#endif

namespace at::native {

enum class MPSCumulativeOpType : uint8_t {
CUMSUM = 0,
CUMPROD = 1,
};

namespace mps {

typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*);
@@ -375,8 +382,12 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
}
}

TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
void cumulative_op_impl(const Tensor& self,
int64_t dim,
c10::optional<ScalarType> dtype,
const Tensor& result,
MPSCumulativeOpType cumulativeOpType,
const std::string& op_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
@@ -389,34 +400,54 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
dim,
")");
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade");
auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
TORCH_WARN_ONCE(op_name, " supported by MPS on MacOS 13+, please upgrade");
Tensor cpu_result;
if (cumulativeOpType == MPSCumulativeOpType::CUMSUM) {
cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
} else if (cumulativeOpType == MPSCumulativeOpType::CUMPROD) {
cpu_result = self.to(at::Device(kCPU)).cumprod(dim, dtype);
}
at::_copy_from_and_resize(cpu_result, result);
return;
}
auto input = dtype.has_value() ? self.to(dtype.value()) : self;

// issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to
// issue #103810551: cumsum / cumprod are broken for int8 and int16; since the chance of overflow is high, cast to
// int32. Fixed in macOS 13.3.
bool castInputData = (isIntegralType(input.scalar_type(), false) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);

TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3");

mps::unary_op(input,
result,
"cumsum_out_mp" + std::to_string(dim),
^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
auto rc = [mpsGraph cumulativeSumWithTensor:inputTensor axis:dim name:nil];
if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
return rc;
});
"MPS does not support ",
op_name,
" op with int64 input. Support has been added in macOS 13.3");

mps::unary_op(
input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
MPSGraphTensor* rc;
if (cumulativeOpType == MPSCumulativeOpType::CUMSUM) {
rc = [mpsGraph cumulativeSumWithTensor:inputTensor axis:dim name:nil];
} else if (cumulativeOpType == MPSCumulativeOpType::CUMPROD) {
rc = [mpsGraph cumulativeProductWithTensor:inputTensor axis:dim name:nil];
}
if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
return rc;
});
}

TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
return cumulative_op_impl(self, dim, dtype, result, MPSCumulativeOpType::CUMSUM, "cumsum_out_mps");
}

TORCH_IMPL_FUNC(cumprod_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
return cumulative_op_impl(self, dim, dtype, result, MPSCumulativeOpType::CUMPROD, "cumprod_out_mps");
}

} // namespace at::native
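
For context, a minimal usage sketch of what this change enables. This is not part of the diff and assumes a macOS 13.3+ machine with an MPS device available:

import torch

# cumprod now runs natively on the MPS backend
x = torch.tensor([2.0, 3.0, 4.0], device="mps")
mps_result = x.cumprod(0)         # values [2., 6., 24.] on mps:0
cpu_result = x.cpu().cumprod(0)   # CPU reference
assert torch.equal(mps_result.cpu(), cpu_result)

As the implementation above shows, int64 inputs still raise "MPS does not support cumprod_out_mps op with int64 input" before macOS 13.3, and on pre-13 systems the op warns once and falls back to a CPU computation.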
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1899,6 +1899,7 @@
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: cumprod_out
MPS: cumprod_out_mps

- func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
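
The one-line dispatch entry above is what routes cumprod calls on MPS tensors to the new kernel instead of the generic "not implemented for MPS" path. A quick sketch of the variants this should cover (assuming an MPS device is present; the functional and in-place forms delegate to the out= kernel registered here, so one registration should cover all three):

import torch

x = torch.ones(4, device="mps", dtype=torch.float32)
out = torch.empty_like(x)

torch.cumprod(x, 0, out=out)   # out= variant, dispatched to cumprod_out_mps
y = x.cumprod(0)               # functional variant
x.cumprod_(0)                  # in-place variant

assert torch.equal(out.cpu(), y.cpu())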
55 changes: 50 additions & 5 deletions test/test_mps.py
@@ -65,7 +65,6 @@ def mps_ops_grad_modifier(ops):

# Unimplemented ops
'__getitem__': [torch.float16],
'prod': [torch.float32], # The operator 'aten::cumprod.out'
'sgn': [torch.float16, torch.float32],
'_segment_reduce': [torch.float16, torch.float32],
'unfold_copy': [torch.float16, torch.float32], # unfold_backward is not implemented
@@ -157,6 +156,9 @@ def mps_ops_grad_modifier(ops):
# fixed in macOS 13. We are not raising error.
'pow': [torch.float32],
'__rpow__': [torch.float32],

# See https://github.com/pytorch/pytorch/issues/106112 for more information
'cumprod': [torch.float32],
}

XPASSLIST_GRAD = {
@@ -336,8 +338,11 @@ def mps_ops_modifier(ops):
'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
# Unsupported dtypes
'cumsum': [torch.int64],
'cumprod': [torch.int64],
'cumulative_trapezoid': [torch.int64],
'masked.cumsum': [torch.int64],
'masked.cumprod': [torch.int64],
'linalg.vander': [torch.int64],
}

MACOS_AFTER_13_1_XFAILLIST = {
@@ -404,7 +409,6 @@ def mps_ops_modifier(ops):
'cholesky_solve': None,
'cummax': None,
'cummin': None,
'cumprod': None,
'digamma': None,
'erfc': None,
'frexp': None,
@@ -449,14 +453,12 @@ def mps_ops_modifier(ops):
'linalg.solve_ex': None,
'linalg.svdvals': None,
'linalg.tensorsolve': None,
'linalg.vander': None,
'linalg.vecdot': None,
'logcumsumexp': None,
'logdet': None,
'lu': None,
'lu_solve': None,
'lu_unpack': None,
'masked.cumprod': None,
'masked.median': None,
'matrix_exp': None,
'mode': None,
@@ -3666,7 +3668,8 @@ def helper(dtype):
helper(torch.int64)
except Exception as e:
e_string = str(e)
self.assertEqual(e_string, "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3")
self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." +
" Support has been added in macOS 13.3")

def test_cumsum_minus_one_axis(self):
def helper(dtype):
@@ -3685,6 +3688,41 @@ def helper(dtype):

[helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

def test_cumprod_all_dtypes(self):
def helper(dtype):
t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

a = t.cumprod(0, dtype=dtype)
a_cpu = t_cpu.cumprod(0, dtype=dtype)

self.assertEqual(a.cpu(), a_cpu)
[helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

try:
helper(torch.int64)
except Exception as e:
e_string = str(e)
self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input."
+ " Support has been added in macOS 13.3")

def test_cumprod_minus_one_axis(self):
def helper(dtype):
# Test with axis -1
cpu_x = None
if(dtype == torch.float32):
cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
else:
cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
x = cpu_x.detach().clone().to('mps')

cpu_y = cpu_x.cumprod(-1)
y = x.cumprod(-1)

self.assertEqual(y, cpu_y)

[helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

def test_median_int16(self):
def helper(shape, dtype):
cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
@@ -7626,6 +7664,13 @@ def test_cumsum_dim_check(self):
self.assertRaises(IndexError, lambda: x.cumsum(2))
self.assertRaises(IndexError, lambda: x.cumsum(-3))

def test_cumprod_dim_check(self):
x = torch.rand((3, 3), device="mps")
self.assertEqual(x.cumprod(1), x.cumprod(-1))
self.assertEqual(x.cumprod(0), x.cumprod(-2))
self.assertRaises(IndexError, lambda: x.cumprod(2))
self.assertRaises(IndexError, lambda: x.cumprod(-3))


class TestTopK(TestCase):
def _test_topk(self, shape, largest):
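
To exercise just the new cases locally, something along these lines should work (a sketch; adjust to however the checkout and test runner are set up):

python -m pytest test/test_mps.py -k "cumprod" -v

This would pick up test_cumprod_all_dtypes, test_cumprod_minus_one_axis, and test_cumprod_dim_check added above, along with any other cumprod-related cases.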