unary ops on strided int tensors fails on MPS #105284
Okay, it does look like resolving this issue is critical to merging #104688, because the test runner crashes when it hits this error, so I can't just add variants of …
@Mr4k I changed the title a bit, as it works fine for regular unary ops, so … And here is a one-line reproducer:
@malfet thanks for checking this out, and thanks for the elegant repro and the help with naming the issue! Did you get a chance to try my other repros? I pulled the latest main and got the following errors for non-cumsum unary operators using your repro. It's possible that my PyTorch build process is not correct and these other repros are a side effect of that, but it has seemed to work so far:
gives me:
gives me:
gives me:
version info:
If these do reproduce, I'm curious why the tensor is set to the output scalar type instead of the input scalar type, and I would be happy to make a PR changing that tomorrow if it's a good idea. However, I may be way off base, as I'm pretty new to the code; either way, this would be a good learning experience!
@Mr4k thank you for the investigation, it looks like I've found the problem: there is a conflict between reshaping the tensor to the output type and casting the input tensor to a different dtype...

```diff
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 2bbd238b1e7..431d97f0836 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -84,7 +84,7 @@ void unary_op(const Tensor& self,
   // If self is densely mapped in storage, create a dense output-like representation
   at::Tensor self_;
   if (!is_dense_in_storage(self)) {
-    self_ = at::empty_like(output);
+    self_ = at::empty_like(output, self.scalar_type());
     mps::mps_copy_(self_, self, false);
   } else {
     self_ = self;
```
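To make the dtype conflict concrete, here is a toy Python model of the dense-copy step the patch touches. Everything in it (`FakeTensor`, `storage_span`, `dense_input`, and the toy `empty_like`) is hypothetical and only mimics the idea. The density test assumes a view is dense when its elements span exactly `numel()` consecutive storage slots, which may differ in detail from ATen's `is_dense_in_storage`. No values are copied; the model only tracks which dtype the scratch tensor ends up with.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: only shape, strides and dtype."""
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]
    dtype: str

    def numel(self) -> int:
        n = 1
        for size in self.shape:
            n *= size
        return n


def storage_span(t: FakeTensor) -> int:
    """Storage slots from the first to the last element, inclusive (assumes non-negative strides)."""
    if t.numel() == 0:
        return 0
    return 1 + sum((size - 1) * stride for size, stride in zip(t.shape, t.strides))


def is_dense_in_storage(t: FakeTensor) -> bool:
    # Toy rule: dense means the view covers exactly numel() consecutive slots, with no gaps.
    return storage_span(t) == t.numel()


def contiguous_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))


def empty_like(t: FakeTensor, dtype: Optional[str] = None) -> FakeTensor:
    # Toy analogue of at::empty_like: same shape, dtype taken from `t` unless overridden.
    return FakeTensor(t.shape, contiguous_strides(t.shape), dtype or t.dtype)


def dense_input(self_t: FakeTensor, output: FakeTensor, fixed: bool) -> FakeTensor:
    """Return the tensor the kernel would read from (hypothetical helper)."""
    if is_dense_in_storage(self_t):
        return self_t
    if fixed:
        return empty_like(output, self_t.dtype)  # after the patch: keep the input dtype
    return empty_like(output)                    # before the patch: inherit the output dtype


if __name__ == "__main__":
    x = FakeTensor((3,), (2,), "int32")    # every other element of a 6-element int32 buffer
    out = FakeTensor((3,), (1,), "int64")  # cumsum promotes integer inputs to int64 by default
    print(is_dense_in_storage(x))                  # False
    print(dense_input(x, out, fixed=False).dtype)  # int64
    print(dense_input(x, out, fixed=True).dtype)   # int32
```

Under this model a transposed view such as `FakeTensor((2, 3), (1, 2), "int32")` counts as dense even though it is not contiguous, so it would be passed through unchanged; only gapped views like the strided slice above take the copy path.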
🐛 Describe the bug
I encountered this problem when I pulled main and saw a new failing test around `torch.vander` in my PR #104688 (the PR technically enables `torch.vander` as a side effect, which is why I was testing it). However, everything below is about main, to isolate it from my changes.
currently results in:
The same thing happens with other unary operators as well:
results in:
I believe this error is due to this line in this recent PR. By changing the line:
to
I resolved this issue. This is a problem because the logic in cumsum currently assumes that the input scalar type is given by the scalar type of the input tensor, not the output tensor. See here. This logic makes me think that the previously mentioned line:
is incorrect. However, I don't know the intention behind adding that line, and if it is correct, I can just alter the cumsum logic in my PR #104688. So I wanted to file a separate issue to discuss this specifically, without tying it to #104688.
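The conflict between the dtype of the dense copy and the cast logic can be sketched with a second toy model, assuming (as described here for cumsum) that the graph-building logic decides its casts from the original input's dtype while the kernel is actually fed the dense scratch copy. `Buffer`, `build_cumsum_graph` and `run` are hypothetical names; this does not reproduce the real MPSGraph code path or its error message.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Buffer:
    """Hypothetical stand-in for the tensor actually fed to the kernel; only dtype matters."""
    dtype: str


def build_cumsum_graph(input_dtype: str, output_dtype: str) -> Tuple[str, List[str]]:
    """Toy graph builder whose cast decision is keyed on the *original* input dtype.

    Returns the dtype its input placeholder expects plus the ops it would emit.
    """
    ops: List[str] = []
    if input_dtype != output_dtype:
        ops.append(f"cast:{input_dtype}->{output_dtype}")
    ops.append(f"cumsum:{output_dtype}")
    return input_dtype, ops


def run(original_input_dtype: str, fed: Buffer, output_dtype: str) -> List[str]:
    expected, ops = build_cumsum_graph(original_input_dtype, output_dtype)
    # The conflict: the graph was planned for one dtype but receives a buffer of another.
    if fed.dtype != expected:
        raise TypeError(f"placeholder expects {expected}, got {fed.dtype}")
    return ops


if __name__ == "__main__":
    # Copy typed like the output: an int32 input's scratch copy becomes int64.
    try:
        run("int32", Buffer("int64"), "int64")
    except TypeError as err:
        print(err)  # placeholder expects int32, got int64
    # Copy typed like the input: matches what the graph was built for.
    print(run("int32", Buffer("int32"), "int64"))  # ['cast:int32->int64', 'cumsum:int64']
```

The point the toy makes is that the scratch copy should differ from the original input only in layout, never in dtype, so any dtype-dependent logic downstream sees the same type whether or not the copy was needed.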
Versions
Collecting environment information...
PyTorch version: 2.1.0.dev20230526
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.3.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.24.2
Libc version: N/A
Python version: 3.9.12 (main, Apr 5 2022, 01:52:34) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-13.3.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.1.0.dev20230526
[pip3] torchaudio==2.1.0.dev20230527
[pip3] torchvision==0.16.0.dev20230526
[conda] numpy 1.24.3 pypi_0 pypi
[conda] torch 2.1.0.dev20230526 pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230527 pypi_0 pypi
[conda] torchvision 0.16.0.dev20230526 pypi_0 pypi
cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev