Add cumprod support for device mps #104688

Closed · wants to merge 23 commits · Changes from 6 commits
47 changes: 46 additions & 1 deletion aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -18,6 +18,7 @@
#include <ATen/ops/cos_native.h>
#include <ATen/ops/cosh_native.h>
#include <ATen/ops/cumsum_native.h>
#include <ATen/ops/cumprod_native.h>
#include <ATen/ops/erf_native.h>
#include <ATen/ops/exp2_native.h>
#include <ATen/ops/exp_native.h>
@@ -393,7 +394,7 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {

mps::unary_op(input,
result,
"cumsum_out_mp" + std::to_string(dim),
"cumsum_out_mps" + std::to_string(dim),
Review comment (Contributor Author): drive-by spelling fix.

^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
@@ -406,4 +407,48 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
});
}

TORCH_IMPL_FUNC(cumprod_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()),
"Expected wrapped dim to be between 0 and ",
self.ndimension(),
" but got ",
wrapped_dim,
"(original dim is ",
dim,
")");
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("torch.cumprod supported by MPS on MacOS 13+, please upgrade");
auto cpu_result = self.to(at::Device(kCPU)).cumprod(dim, dtype);
at::_copy_from_and_resize(cpu_result, result);
return;
}
auto input = dtype.has_value() ? self.to(dtype.value()) : self;

// issue #103810551: cumprod is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to
Review comment (@Mr4k, Contributor Author, Jul 6, 2023): I'm not sure what issue #103810551 is, and I want to check whether it applies to cumprod as well as cumsum. So far I have assumed it does, but seeing the actual issue in whatever tracker it lives in would help. If it doesn't apply, I could remove a decent chunk of this function; if it does, I can dedupe this with cumsum and make a general helper function.

Review comment (Contributor Author): From the work done on kulinseth#419, it appears the constraint does hold for both functions, so I will merge them into a single one as is done in that PR.

// int32 fixed in macOS 13.3
bool castInputData = (isIntegralType(input.scalar_type(), false) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);

TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support cumprod op with int64 input. Support has been added in macOS 13.3");

mps::unary_op(input,
result,
"cumprod_out_mps" + std::to_string(dim),
^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
auto rc = [mpsGraph cumulativeProductWithTensor:inputTensor axis:dim name:nil];
if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
return rc;
});
}
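Note: the pre-macOS-13 branch above uses a common fallback pattern: warn once, compute on CPU, and copy the result back into the MPS output tensor. A hypothetical Python-level rendering of the same idea (names and the `op_supported` flag are illustrative, not part of the PR):

```python
import torch

def cumprod_with_cpu_fallback(t: torch.Tensor, dim: int, op_supported: bool) -> torch.Tensor:
    """Sketch of the fallback pattern: compute on CPU, then copy back to the original device."""
    if not op_supported:
        return t.cpu().cumprod(dim).to(t.device)
    return t.cumprod(dim)
```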

} // namespace at::native
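Note: the int8/int16 overflow concern behind the `castInputData` branch above is easy to reproduce on CPU. A quick sketch (illustrative, not part of the PR; the wrapped value assumes the usual two's-complement overflow behavior):

```python
import torch

# int8 tops out at 127, so a cumulative product overflows almost immediately:
# 3**5 = 243 no longer fits and wraps around to -13.
t = torch.full((5,), 3, dtype=torch.int8)
print(t.cumprod(0))                  # tensor([  3,   9,  27,  81, -13], dtype=torch.int8)
print(t.to(torch.int32).cumprod(0))  # tensor([  3,   9,  27,  81, 243], dtype=torch.int32)
```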
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1899,6 +1899,7 @@
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: cumprod_out
+ MPS: cumprod_out_mps

- func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
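Note: with the dispatch entry above registered, the op is reachable through the ordinary tensor API. A minimal smoke test, assuming a machine where the MPS backend is available:

```python
import torch

if torch.backends.mps.is_available():
    x = torch.arange(1, 5, dtype=torch.float32, device="mps")
    print(x.cumprod(0))  # tensor([ 1.,  2.,  6., 24.], device='mps:0')
```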
86 changes: 82 additions & 4 deletions test/test_mps.py
@@ -66,7 +66,6 @@ def mps_ops_grad_modifier(ops):

# Unimplemented ops
'__getitem__': [torch.float16],
- 'prod': [torch.float32], # The operator 'aten::cumprod.out'
'sgn': [torch.float16, torch.float32],
'_segment_reduce': [torch.float16, torch.float32],
'unfold_copy': [torch.float16, torch.float32], # unfold_backward is not implemented
@@ -137,6 +136,7 @@ def mps_ops_grad_modifier(ops):
# fixed in macOS 13. We are not raising error.
'__rpow__': [torch.float32],
'pow': [torch.float32],
+ 'prod': [torch.float32],  # The operator 'aten::cumprod.out' is not supported until macOS 13
}

MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
@@ -337,8 +337,10 @@ def mps_ops_modifier(ops):
'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
# Unsupported dtypes
'cumsum': [torch.int64],
+ 'cumprod': [torch.int64],
'cumulative_trapezoid': [torch.int64],
'masked.cumsum': [torch.int64],
+ 'masked.cumprod': [torch.int64],
}

MACOS_AFTER_13_1_XFAILLIST = {
@@ -405,7 +407,6 @@ def mps_ops_modifier(ops):
'cholesky_solve': None,
'cummax': None,
'cummin': None,
- 'cumprod': None,
'digamma': None,
'erfc': None,
'erfinv': None,
@@ -452,14 +453,12 @@ def mps_ops_modifier(ops):
'linalg.solve_ex': None,
'linalg.svdvals': None,
'linalg.tensorsolve': None,
- 'linalg.vander': None,
'linalg.vecdot': None,
'logcumsumexp': None,
'logdet': None,
'lu': None,
'lu_solve': None,
'lu_unpack': None,
- 'masked.cumprod': None,
'masked.median': None,
'matrix_exp': None,
'mode': None,
@@ -3682,6 +3681,54 @@ def helper(dtype):

[helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

def test_cumprod_all_dtypes(self):
def helper(dtype):
t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

a = t.cumprod(0, dtype=dtype)
a_cpu = t_cpu.cumprod(0, dtype=dtype)

self.assertEqual(a.cpu(), a_cpu)
[helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

try:
helper(torch.int64)
except Exception as e:
e_string = str(e)
self.assertEqual(e_string, "MPS does not support cumprod op with int64 input. Support has been added in macOS 13.3")

def test_cumprod_minus_one_axis(self):
def helper(dtype):
# Test with axis -1
# note: data is generated as float32 for every dtype here; the dtype argument only picks the distribution
if dtype == torch.float32:
cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
else:
cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
x = cpu_x.detach().clone().to('mps')

cpu_y = cpu_x.cumprod(-1)
y = x.cumprod(-1)

self.assertEqual(y, cpu_y)

[helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_cumprod_backward(self):
# the gradient computation of cumprod has two different paths depending on whether or not multiple derivatives are to be computed
for multiple_derivatives in [False, True]:
t = torch.tensor([1.0, 2.0, 3.0, 4.0], device="mps", dtype=torch.float).detach().requires_grad_()
t_cpu = torch.tensor([1.0, 2.0, 3.0, 4.0], device="cpu", dtype=torch.float).detach().requires_grad_()

gradient = torch.full_like(t, 2)
gradient_cpu = torch.full_like(t_cpu, 2)
self.assertEqual(
torch.autograd.grad(t.cumprod(0), t, grad_outputs=gradient, create_graph=multiple_derivatives),
torch.autograd.grad(t_cpu.cumprod(0), t_cpu, grad_outputs=gradient_cpu, create_graph=multiple_derivatives)
)
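Note: the two paths the loop above exercises come from create_graph: with create_graph=True the backward pass itself must be differentiable so that higher-order derivatives can be taken. A standalone CPU sketch of a double backward through cumprod (illustrative, not part of the PR):

```python
import torch

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# First derivative of sum(cumprod(t)); create_graph keeps it differentiable.
(g,) = torch.autograd.grad(t.cumprod(0).sum(), t, create_graph=True)
# Second derivative, only possible because g still carries a graph.
(gg,) = torch.autograd.grad(g.sum(), t)
print(g)   # tensor([9., 4., 2.], grad_fn=...)
print(gg)  # tensor([6., 5., 3.])
```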

def test_median_int16(self):
def helper(shape, dtype):
cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
@@ -5175,6 +5222,30 @@ def helper(shape, dtype=torch.float32):
for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
helper((2, 3), dtype)

# torch.cumprod is not supported on mps until macOS 13
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_prod_backward(self):
t = torch.tensor([1.0, 2.0, 3.0, 4.0], device="mps", dtype=torch.float).detach().requires_grad_()
t_cpu = torch.tensor([1.0, 2.0, 3.0, 4.0], device="cpu", dtype=torch.float).detach().requires_grad_()

gradient = torch.full_like(t, 2)
gradient_cpu = torch.full_like(t_cpu, 2)
self.assertEqual(
torch.autograd.grad(t.cumprod(0), t, grad_outputs=gradient, create_graph=False),
torch.autograd.grad(t_cpu.cumprod(0), t_cpu, grad_outputs=gradient_cpu, create_graph=False)
)

# when there is exactly one zero in the tensor the gradient of prod involves cumprod
t = torch.tensor([1.0, 2.0, 0.0, 3.0, 4.0], device="mps", dtype=torch.float).detach().requires_grad_()
t_cpu = torch.tensor([1.0, 2.0, 0.0, 3.0, 4.0], device="cpu", dtype=torch.float).detach().requires_grad_()

gradient = torch.full_like(t, 2)
gradient_cpu = torch.full_like(t_cpu, 2)
self.assertEqual(
torch.autograd.grad(t.cumprod(0), t, grad_outputs=gradient, create_graph=False),
torch.autograd.grad(t_cpu.cumprod(0), t_cpu, grad_outputs=gradient_cpu, create_graph=False)
)
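Note: the single-zero case above leans on the gradient identity for prod: d prod(x) / dx_i is the product of all x_j with j != i, which is nonzero only at the zero's position and, as the comment says, is evaluated with cumprod-based exclusive products. A quick CPU check (illustrative, not part of the PR):

```python
import torch

x = torch.tensor([1.0, 2.0, 0.0, 3.0, 4.0], requires_grad=True)
x.prod().backward()
# Only the slot holding the zero gets a nonzero gradient: 1 * 2 * 3 * 4 = 24.
print(x.grad)  # tensor([ 0.,  0., 24.,  0.,  0.])
```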

# Test forward mean
def test_mean(self):
def helper(n, c, h, w):
@@ -7568,6 +7639,13 @@ def test_cumsum_dim_check(self):
self.assertRaises(IndexError, lambda: x.cumsum(2))
self.assertRaises(IndexError, lambda: x.cumsum(-3))

def test_cumprod_dim_check(self):
x = torch.rand((3, 3), device="mps")
self.assertEqual(x.cumprod(1), x.cumprod(-1))
self.assertEqual(x.cumprod(0), x.cumprod(-2))
self.assertRaises(IndexError, lambda: x.cumprod(2))
self.assertRaises(IndexError, lambda: x.cumprod(-3))
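Note: the negative-axis behavior checked above comes from dim wrapping (maybe_wrap_dim on the ATen side): for an n-dimensional tensor the valid range is [-n, n - 1], and negative dims wrap by adding n. A hypothetical pure-Python rendering of the rule:

```python
def wrap_dim(dim: int, ndim: int) -> int:
    # Valid range is [-ndim, ndim - 1]; negative dims wrap by adding ndim.
    if not -ndim <= dim <= ndim - 1:
        raise IndexError(f"dim {dim} out of range for a {ndim}-dimensional tensor")
    return dim + ndim if dim < 0 else dim

assert wrap_dim(-1, 2) == 1
assert wrap_dim(-2, 2) == 0
```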


class TestTopK(TestCase):
def _test_topk(self, shape, largest):