[MPS] Bug on training CNN+LSTM #83144
Comments
Can you please provide a minimal repro? Also, did you set
The model is rather complicated; I will try to prepare a short demo. The error is exactly like the one in #78429, so you can take that as a reference. From this error message alone, it is hard for me to tell whether the culprit is nn.LayerNorm or nn.LSTM. For the second question: yes, I did set 'batch_first = True'.
The same script ran fine on v1.12.0, but I had manually adjusted the shape of the result matrix from nn.LSTM, because the MPS backend's output differs from the cuDNN backend's.
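For context on the shape adjustment mentioned above: with `batch_first=True`, `nn.LSTM` takes and returns `(batch, seq, feature)` tensors, while the returned `h_n`/`c_n` states stay `(num_layers, batch, hidden)` regardless of the flag. A minimal CPU sketch; all sizes are arbitrary illustrations, not taken from the reporter's model:

```python
import torch
import torch.nn as nn

# Illustrative sizes only (hypothetical, not from the original model).
batch, seq_len, n_feat, hidden, layers = 4, 10, 8, 16, 2

lstm = nn.LSTM(input_size=n_feat, hidden_size=hidden, num_layers=layers, batch_first=True)
x = torch.randn(batch, seq_len, n_feat)  # (batch, seq, feature) because batch_first=True

out, (h_n, c_n) = lstm(x)
print(tuple(out.shape))  # (4, 10, 16): output follows batch_first
print(tuple(h_n.shape))  # (2, 4, 16): h_n is (num_layers, batch, hidden) either way

# The same weights in a sequence-first LSTM (batch_first=False) expect (seq, batch, feature);
# transposing the input and output reproduces the batch-first result.
lstm_sf = nn.LSTM(n_feat, hidden, layers)
lstm_sf.load_state_dict(lstm.state_dict())
out_sf, _ = lstm_sf(x.transpose(0, 1))
print(torch.allclose(out_sf.transpose(0, 1), out, atol=1e-6))  # True
```

If a backend returned the sequence-first layout despite `batch_first=True`, a manual `transpose(0, 1)` like the one above is the kind of adjustment the reporter describes.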
The problem is that the LSTM backward pass on the MPS backend has a computational correctness issue that hasn't been resolved yet, and it currently doesn't correctly take care of … Now, the … Here is a related issue: #80306.
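One way to check a backward-pass correctness problem like this is to run the same LSTM, with identical weights and inputs, on CPU and on the MPS device and compare the gradients. A hedged sketch: the helper name `lstm_grads` and the `atol=1e-4` tolerance are arbitrary choices, and when MPS is not available the script compares CPU against CPU so it still runs:

```python
import copy

import torch
import torch.nn as nn


def lstm_grads(lstm, x, device):
    """Run forward + backward on `device`; return input and parameter gradients on CPU."""
    model = copy.deepcopy(lstm).to(device)  # identical weights on every device
    inp = x.detach().clone().to(device).requires_grad_(True)
    out, _ = model(inp)
    out.sum().backward()
    grads = {"input": inp.grad.cpu()}
    for name, p in model.named_parameters():
        grads[name] = p.grad.cpu()
    return grads


torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 8)

# Use MPS when present; otherwise CPU is compared with itself (trivially equal).
device = "mps" if torch.backends.mps.is_available() else "cpu"
ref = lstm_grads(lstm, x, "cpu")
dev_grads = lstm_grads(lstm, x, device)

mismatched = [k for k in ref if not torch.allclose(ref[k], dev_grads[k], atol=1e-4)]
print(device, "mismatched gradients:", mismatched)
```

Listing mismatches per parameter (for example only `weight_hh_l1`) can help narrow down whether the multi-layer path or the `batch_first` handling is involved.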
@dominicshanshan could you please try a newer version of PyTorch, such as the latest nightly? If using pip, you can use:
Yep, but can I choose to build the nightly version from
@DenisVieriu97, I just tried the latest nightly and still get the same error as with v1.12.1.
Env info:
OS: macOS 12.5 (arm64)
Python version: 3.10.4 (main, Mar 31 2022, 03:37:37) [Clang 12.0.0 ] (64-bit runtime)
Versions of relevant libraries:
@dominicshanshan thanks for trying the latest nightly and for the update!
I've been busy these days, sorry for the late reply. The model is somewhat private, but I will try to provide toy code for you.
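A toy model of the general shape discussed in this thread (a convolutional front end feeding a `batch_first` LSTM, plus an `nn.LayerNorm`, since both modules were named as suspects) might look like the sketch below. The class name `ToyCNNLSTM`, the architecture, and all sizes are hypothetical, not the reporter's private model; it runs one training step on MPS when available, otherwise on CPU:

```python
import torch
import torch.nn as nn


class ToyCNNLSTM(nn.Module):
    """Hypothetical CNN+LSTM: Conv1d over time, LayerNorm, batch_first LSTM, linear head."""

    def __init__(self, in_ch=3, conv_ch=32, hidden=64, n_classes=5):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, conv_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.norm = nn.LayerNorm(conv_ch)
        self.lstm = nn.LSTM(conv_ch, hidden, num_layers=2, batch_first=True)
        self.head = nn.Linear(hidden, n_classes)

    def forward(self, x):              # x: (batch, channels, time)
        feats = self.conv(x)           # (batch, conv_ch, time)
        feats = feats.transpose(1, 2)  # (batch, time, conv_ch) for batch_first
        feats = self.norm(feats)       # normalize over the feature dimension
        out, _ = self.lstm(feats)      # (batch, time, hidden)
        return self.head(out[:, -1])   # logits from the last time step


device = "mps" if torch.backends.mps.is_available() else "cpu"
model = ToyCNNLSTM().to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(8, 3, 50).to(device)
y = torch.randint(0, 5, (8,)).to(device)

opt.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()  # backward through the LSTM: the step the thread identifies as problematic on MPS
opt.step()
print(device, loss.item())
```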
The native implementation of LSTM has been fixed on macOS 13. On macOS 12, the multi-layer LSTM still has a numerical correctness issue that cannot be resolved on the OS side. Thus, on macOS 12 we fall back from the multi-layer LSTM to an LSTMCell iteration. This might have a performance impact, but it makes LSTM on macOS 12 fully usable. Fixes: #90421 Issues related: #80306, #83144 Pull Request resolved: #90909 Approved by: https://github.com/albanD, https://github.com/kulinseth
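The fallback described in that commit message, running a multi-layer LSTM as an explicit loop over `LSTMCell`s, can be sketched in plain PyTorch as below. This illustrates the idea only; it is not the actual MPS backend code from #90909. The helper `lstm_via_cells` is hypothetical: it copies the weights of an `nn.LSTM` into one `nn.LSTMCell` per layer (assuming unidirectional, no projection, no dropout, `batch_first=True`) and checks that both produce the same output:

```python
import torch
import torch.nn as nn


def lstm_via_cells(lstm: nn.LSTM, x: torch.Tensor) -> torch.Tensor:
    """Emulate a unidirectional, batch_first nn.LSTM by iterating one LSTMCell per layer."""
    assert lstm.batch_first and not lstm.bidirectional and lstm.proj_size == 0
    cells = []
    for layer in range(lstm.num_layers):
        in_size = lstm.input_size if layer == 0 else lstm.hidden_size
        cell = nn.LSTMCell(in_size, lstm.hidden_size, bias=lstm.bias)
        # nn.LSTM and nn.LSTMCell use the same stacked (i, f, g, o) gate layout.
        # Weights are copied, not shared, so gradients would land on the cells.
        with torch.no_grad():
            cell.weight_ih.copy_(getattr(lstm, f"weight_ih_l{layer}"))
            cell.weight_hh.copy_(getattr(lstm, f"weight_hh_l{layer}"))
            if lstm.bias:
                cell.bias_ih.copy_(getattr(lstm, f"bias_ih_l{layer}"))
                cell.bias_hh.copy_(getattr(lstm, f"bias_hh_l{layer}"))
        cells.append(cell)

    batch, seq_len, _ = x.shape
    layer_input = x
    for cell in cells:
        # Zero initial states, matching nn.LSTM's default when no (h0, c0) is given.
        h = x.new_zeros((batch, lstm.hidden_size))
        c = x.new_zeros((batch, lstm.hidden_size))
        steps = []
        for t in range(seq_len):
            h, c = cell(layer_input[:, t], (h, c))
            steps.append(h)
        layer_input = torch.stack(steps, dim=1)  # (batch, seq, hidden) feeds the next layer
    return layer_input


torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=3, batch_first=True)
x = torch.randn(4, 10, 8)
ref, _ = lstm(x)
emu = lstm_via_cells(lstm, x)
print(torch.allclose(ref, emu, atol=1e-5))  # True
```

The Python-level loop over time steps and layers is also why such a fallback can cost performance, as the commit message notes.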
🐛 Describe the bug
When training a CNN+LSTM model on the M1 Max GPU with PyTorch v1.12.1, I get this error:
loc("total derivative last state"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/20d6c351-ee94-11ec-bcaf-7247572f23b4/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":219:0)): error: input types 'tensor<1x82x64xf32>' and 'tensor<1x32x64xf32>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
This did not happen on the previous PyTorch v1.12.0. I guess something is wrong with the new transformation of the LSTM result matrix?
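The MPSGraph message above is a broadcasting failure. Under the standard NumPy/PyTorch broadcasting rule, two shapes are compatible only if, aligned from the trailing dimension, each pair of sizes is equal or one of them is 1. A small pure-Python sketch of that rule (the function name `broadcast_shape` is illustrative) shows why `1x82x64` and `1x32x64` are rejected: the middle sizes 82 and 32 differ and neither is 1. The report does not say which model dimensions 82 and 32 correspond to.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of `a` and `b`, or None if they are incompatible."""
    result = []
    # Walk both shapes from the trailing end; missing leading dims behave like size 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            return None  # sizes differ and neither is 1
    return tuple(reversed(result))


print(broadcast_shape((1, 82, 64), (1, 32, 64)))  # None: 82 vs 32
print(broadcast_shape((1, 82, 64), (1, 1, 64)))   # (1, 82, 64)
```

`torch.broadcast_shapes` applies the same rule and raises an error for the first pair.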
Versions
Collecting environment information...
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 12.5 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.5 | packaged by conda-forge | (main, Jun 14 2022, 07:07:06) [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-12.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] torch==1.12.1
[pip3] torchaudio==0.12.1
[pip3] torchinfo==1.7.0
[pip3] torchvision==0.13.1
[conda] numpy 1.23.1 py310h220015d_0
[conda] numpy-base 1.23.1 py310h742c864_0
[conda] pytorch 1.12.1 py3.10_0 pytorch
[conda] torchaudio 0.12.1 py310_cpu pytorch
[conda] torchinfo 1.7.0 pyhd8ed1ab_0 conda-forge
[conda] torchvision 0.13.1 py310_cpu pytorch
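The block above is the output of PyTorch's environment script (`python -m torch.utils.collect_env`). Because the fix in #90909 behaves differently on macOS 12 and macOS 13, the macOS version and MPS availability are the key fields for an MPS report. A minimal sketch that collects just those, using the standard `platform` module and `torch.backends.mps` (the helper name `mps_env_summary` is hypothetical):

```python
import platform

import torch


def mps_env_summary() -> dict:
    """Collect the version fields most relevant to an MPS bug report."""
    return {
        "torch": torch.__version__,
        "python": platform.python_version(),
        "macos": platform.mac_ver()[0] or "n/a",  # mac_ver() returns '' off macOS
        "machine": platform.machine(),            # e.g. 'arm64' on Apple silicon
        "mps_built": torch.backends.mps.is_built(),
        "mps_available": torch.backends.mps.is_available(),
    }


for key, value in mps_env_summary().items():
    print(f"{key}: {value}")
```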
cc @kulinseth @albanD