-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add cumprod support for device mps #104688
Changes from 6 commits
5a7abba
d1ad744
2f3fd26
6469521
abd0638
065e060
b0c6122
8dec7e6
6cefe4d
382b758
6345dd5
8bd1740
58bb641
5d50a78
5311353
fb24d18
f2cd347
2eafe33
f3c73f0
54ffc9e
41b9d63
13ff5a3
cca389b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
#include <ATen/ops/cos_native.h> | ||
#include <ATen/ops/cosh_native.h> | ||
#include <ATen/ops/cumsum_native.h> | ||
#include <ATen/ops/cumprod_native.h> | ||
#include <ATen/ops/erf_native.h> | ||
#include <ATen/ops/exp2_native.h> | ||
#include <ATen/ops/exp_native.h> | ||
|
@@ -393,7 +394,7 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) { | |
|
||
mps::unary_op(input, | ||
result, | ||
"cumsum_out_mp" + std::to_string(dim), | ||
"cumsum_out_mps" + std::to_string(dim), | ||
^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { | ||
if (castInputData) { | ||
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int); | ||
|
@@ -406,4 +407,48 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) { | |
}); | ||
} | ||
|
||
// MPS implementation of cumprod: writes the cumulative product of `self` along `dim` into `result`. | ||
// If `dtype` is provided, `self` is converted to that dtype before the product is computed. | ||
TORCH_IMPL_FUNC(cumprod_out_mps) | ||
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) { | ||
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); | ||
auto nDims = self.dim(); | ||
auto wrapped_dim = maybe_wrap_dim(dim, nDims); | ||
// Sanity-check the wrapped dimension. The upper bound is max(1, ndim), so 0 is accepted for a 0-dim tensor. | ||
// Note that `nDims` and `wrapped_dim` are only used for this check. | ||
TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()), | ||
"Expected wrapped dim to be between 0 and ", | ||
self.ndimension(), | ||
" but got ", | ||
wrapped_dim, | ||
"(original dim is ", | ||
dim, | ||
")"); | ||
// Fallback when the OS check fails: warn once, compute cumprod on the CPU, then copy the result into `result`. | ||
if (!is_macos_13_or_newer()) { | ||
TORCH_WARN_ONCE("torch.cumprod supported by MPS on MacOS 13+, please upgrade"); | ||
auto cpu_result = self.to(at::Device(kCPU)).cumprod(dim, dtype); | ||
at::_copy_from_and_resize(cpu_result, result); | ||
return; | ||
} | ||
// Apply the requested dtype up front; otherwise operate on `self` as-is. | ||
auto input = dtype.has_value() ? self.to(dtype.value()) : self; | ||
|
||
// issue #103810551: cumprod is horribly broken for int8, int16, and since the chance of overflow is pretty high, cast to | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. from the work done on kulinseth#419, it appears constraint does hold for both functions so I will merge them into a single one as is done in this pr |
||
// int32; fixed in macOS 13.3 | ||
// True for integral inputs other than Int and Long; those are cast to Int inside the graph below. | ||
// (The `false` argument presumably excludes bool from "integral" — TODO confirm.) | ||
bool castInputData = (isIntegralType(input.scalar_type(), false) && input.scalar_type() != ScalarType::Int && | ||
input.scalar_type() != ScalarType::Long); | ||
|
||
// Long (int64) input is rejected unless the macOS 13.3+ check passed. | ||
TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long, | ||
"MPS does not support cumprod op with int64 input. Support has been added in macOS 13.3"); | ||
|
||
// The string argument embeds `dim` (presumably a key used by unary_op to cache the graph — TODO confirm). | ||
mps::unary_op(input, | ||
result, | ||
"cumprod_out_mps" + std::to_string(dim), | ||
^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { | ||
if (castInputData) { | ||
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int); | ||
} | ||
// NOTE(review): the axis is the caller's original `dim`, not `wrapped_dim`; confirm negative axes are accepted here. | ||
auto rc = [mpsGraph cumulativeProductWithTensor:inputTensor axis:dim name:nil]; | ||
// Cast back to result's dtype when the graph output type differs from it or when the input was cast to Int above. | ||
if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) { | ||
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type()); | ||
} | ||
return rc; | ||
}); | ||
} | ||
|
||
} // namespace at::native |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drive-by spelling fix.