LSTM Output Transposed w/MPS on 1.13 nightly build #80306
Comments
Minimal case confirming this issue. Verified on PyTorch version […].

```python
import torch
from torch.nn import LSTM

assert torch.backends.mps.is_available()

tensor_shapes = {}
for device_type in ("mps", "cpu"):
    device = torch.device(device_type)
    lstm = LSTM(64, 128, 1, batch_first=True)
    lstm.to(device)
    input_tensor = torch.randn((2, 4, 64)).to(device)
    output, _ = lstm(input_tensor)
    tensor_shapes[device_type] = output.shape

print(tensor_shapes)
# prints {'mps': torch.Size([4, 2, 128]), 'cpu': torch.Size([2, 4, 128])}
#                            ^^^^                            ^^^^
```

Interestingly, this doesn't happen when `batch_first` is not set:

```python
import torch
from torch.nn import LSTM

assert torch.backends.mps.is_available()

tensor_shapes = {}
for device_type in ("mps", "cpu"):
    device = torch.device(device_type)
    lstm = LSTM(64, 128, 1)
    lstm.to(device)
    input_tensor = torch.randn((2, 4, 64)).to(device)
    output, _ = lstm(input_tensor)
    tensor_shapes[device_type] = output.shape

print(tensor_shapes)  # prints {'mps': torch.Size([2, 4, 128]), 'cpu': torch.Size([2, 4, 128])}
```
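A useful invariant for repros like the ones above: `nn.LSTM` does not reorder the two leading dimensions of a batched input. They are `(batch, seq)` with `batch_first=True` and `(seq, batch)` otherwise, and in both cases the output keeps that order; only the feature dimension changes to `D * H_out` (`D = 2` if bidirectional, `H_out = proj_size` if nonzero, else `hidden_size`). A minimal pure-Python sketch of that rule follows; `expected_lstm_output_shape` is a hypothetical helper, not a PyTorch API.

```python
def expected_lstm_output_shape(input_shape, hidden_size, bidirectional=False, proj_size=0):
    """Hypothetical helper: expected nn.LSTM output shape for a batched 3-D input.

    nn.LSTM keeps the order of the two leading input dimensions:
    (batch, seq, F) -> (batch, seq, ...) when batch_first=True,
    (seq, batch, F) -> (seq, batch, ...) when batch_first=False.
    Only the last (feature) dimension changes.
    """
    if len(input_shape) != 3:
        raise ValueError(f"expected a batched 3-D input shape, got {input_shape!r}")
    num_directions = 2 if bidirectional else 1
    out_features = proj_size if proj_size > 0 else hidden_size
    return (input_shape[0], input_shape[1], num_directions * out_features)


print(expected_lstm_output_shape((2, 4, 64), 128))                      # (2, 4, 128)
print(expected_lstm_output_shape((2, 4, 64), 128, bidirectional=True))  # (2, 4, 256)
```

Because the rule does not depend on `batch_first`, a repro can simply assert `tuple(output.shape) == expected_lstm_output_shape(tuple(input_tensor.shape), 128)`. With `batch_first=True` that check would fail for `"mps"` on the affected nightly, which returned `(4, 2, 128)` instead of `(2, 4, 128)`.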
Also, this doesn't happen with GRU:

```python
import torch
from torch.nn import GRU

assert torch.backends.mps.is_available()

tensor_shapes = {}
for device_type in ("mps", "cpu"):
    device = torch.device(device_type)
    gru = GRU(64, 128, 1, batch_first=True)
    gru.to(device)
    input_tensor = torch.randn((2, 4, 64)).to(device)
    output, _ = gru(input_tensor)
    tensor_shapes[device_type] = output.shape

print(tensor_shapes)  # prints {'mps': torch.Size([2, 4, 128]), 'cpu': torch.Size([2, 4, 128])}
```
Should be fixed by the linked PR.
Summary: The output of LSTM with `batch_first` should be transposed back to batch first format.
Fixes #80306
Pull Request resolved: #80597
Approved by: https://github.com/kulinseth
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b0b24b4285dae33cf8699a3b25b37675a655ba4b
Reviewed By: mehtanirav
Differential Revision: D37687576
Pulled By: mehtanirav
fbshipit-source-id: d0837b4aa3e1f5f33cc614dd5fe14b532fb96304
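The fix summarized above amounts to a layout round-trip: if the computation runs in sequence-first layout, a `batch_first=True` input is transposed to `(seq, batch, features)` on the way in, and the output must be transposed back to `(batch, seq, features)` on the way out. Below is a minimal CPU-only sketch of that round-trip, checked against a `batch_first=True` module with the same weights. It is an illustration under that assumption, not the actual MPS backend code, and `batch_first_via_seq_first` is a hypothetical helper.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Reference: a module that accepts (batch, seq, features) input directly.
lstm_bf = nn.LSTM(64, 128, 1, batch_first=True)

# Same weights, but configured for (seq, batch, features) input.
lstm_sf = nn.LSTM(64, 128, 1, batch_first=False)
lstm_sf.load_state_dict(lstm_bf.state_dict())


def batch_first_via_seq_first(lstm, x_bf):
    """Hypothetical helper: run a seq-first LSTM on batch-first input."""
    x_sf = x_bf.transpose(0, 1)       # (batch, seq, F) -> (seq, batch, F)
    out_sf, (h_n, c_n) = lstm(x_sf)   # out_sf: (seq, batch, hidden)
    # batch_first never applies to h_n / c_n, which stay (layers * D, batch, hidden),
    # so only the output is transposed back to batch-first.
    return out_sf.transpose(0, 1), (h_n, c_n)


x = torch.randn(2, 4, 64)  # (batch=2, seq=4, features=64), as in the repro
out_ref, (h_ref, c_ref) = lstm_bf(x)
out, (h, c) = batch_first_via_seq_first(lstm_sf, x)

print(out.shape)  # torch.Size([2, 4, 128])
print(torch.allclose(out, out_ref, atol=1e-5))
```

Dropping the final `transpose(0, 1)` would return `out_sf` with shape `(4, 2, 128)`, the same shape the MPS backend produced in the original report.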
Hi, the committed changes did not fix the issue. While the model's output order is now correct for […] This simple example demonstrates the problem. It runs fine with […]
The error is: […]
It looks like it is still not flipping the batch (64) and sequence (10) dimensions when training.
* MPS: Fixes (#78930)
  Cast integer to float in UnaryOps
  Add tensor dtype in key generation
  Enable FP16 scalars and use placeholder for alpha tensor in add/sum ops
  Fixes #ISSUE_NUMBER
  Pull Request resolved: #78930
  Approved by: https://github.com/albanD
* MPS: Binary cast fix by proper type promotion and remove spurious copy warning (#79185)
  Fixes #78019, #78020
  Fixes #79185
  Pull Request resolved: #79185
  Approved by: https://github.com/albanD, https://github.com/razarmehr
* MPS: add exponential op (#79188)
  Add exponential distribution
  Fixes #ISSUE_NUMBER
  Pull Request resolved: #79188
  Approved by: https://github.com/razarmehr, https://github.com/albanD
* [MPS] Delete unused vars from OperationUtils.mm
  Pull Request resolved: #79514
  Approved by: https://github.com/kulinseth, https://github.com/albanD
* [MPS] Fix getDefaultGenerator and copy_kernel_mps
  Returning reference to stack memory is really bad
  Pull Request resolved: #79515
  Approved by: https://github.com/albanD
* [MPS][BE] Do not use `new/delete[]` in `chainViewOperation`
  `std::array` will do just fine
  Pull Request resolved: #79516
  Approved by: https://github.com/albanD
* [MPS] Support stride of stride
  Fixes #79181
  Pull Request resolved: #79521
  Approved by: https://github.com/kulinseth
* MPS: TopK raise an error if K>16 (#79677)
  * Error out in TopK when k>16.
  * Add a test case too.
  Fixes #78915
  Pull Request resolved: #79677
  Approved by: https://github.com/albanD
* [MPS]: Add fix for squeezed input axes handling in BCE loss (#79676)
  Fixes #79527
  Pull Request resolved: #79676
  Approved by: https://github.com/razarmehr, https://github.com/albanD
* MPS: Add amax and amin Ops with tests (#79682)
  * Add amax and amin with tests
  Fixes #ISSUE_NUMBER
  Pull Request resolved: #79682
  Approved by: https://github.com/albanD
* [MPS] Fix torch.uint8 support (#80049)
  `ScalarType.Byte` should be cast to `MPSDataTypeUInt8`
  And support for `torch.int8` as well as test those conversions in `TestMPS.test_to`
  Fixes #80006
  Pull Request resolved: #80049
  Approved by: https://github.com/albanD
* [MPS] Fix binary ops between int32 tensor with int64 scalar (#80220)
  For some reason, tensor *op* scalar does not follow the normal binary promotion rules
  So cast output tensor to expected type if needed
  It seems that one should have casted input tensors to expected output tensor type, but it does not really work for boolean binary ops, so...
  Add output tensor type/shape to cached graph key
  Extend `TestMPS.test_add_scalars` to test for this regression
  Fixes #79835
  Pull Request resolved: #80220
  Approved by: https://github.com/albanD
* [MPS] Add equal operator (#80195)
  Which is, in essence is composite of `eq`->`all`->`item`
  `native/mps/operators/Equal.cpp` is an almost verbatim copy of `native/cuda/Equal.cpp`
  Fix codegen by generating MPSFunctions headers
  Pull Request resolved: #80195
  Approved by: https://github.com/albanD
* [MPS] add `aten::normal.Tensor_float` `aten::normal.float_Tensor` `aten::normal.Tensor_Tensor` (#80297)
  Fixes #ISSUE_NUMBER
  Pull Request resolved: #80297
  Approved by: https://github.com/albanD, https://github.com/kulinseth
* [MPS] Add flip (#80214)
  Fixes #ISSUE_NUMBER
  Pull Request resolved: #80214
  Approved by: https://github.com/DenisVieriu97, https://github.com/albanD
* [MPS] Add logical ops (#80216)
  This PR adds `logical_not`, `logical_and`, `logical_or`, `logical_xor`.
  Pull Request resolved: #80216
  Approved by: https://github.com/albanD, https://github.com/kulinseth
* [MPS] Add glu (#79866)
  Adds mps op for `aten::glu.out`.
  Pull Request resolved: #79866
  Approved by: https://github.com/kulinseth, https://github.com/albanD
* [MPS] Fix std/var cache issue (#80502)
  Use `getTensorsStringKey` which has tensor shape info added as part of the key to prevent cache lookup issue when the shape of input tensor is changed.
  Fixes #80499
  Pull Request resolved: #80502
  Approved by: https://github.com/malfet, https://github.com/kulinseth
* Add scatter support for view operations (#79939)
  * Add scatter support for view operations; #78074, #78886, #79672
  * Update test_slicing_replace_column to properly test different sizes
  * Handle in-place changes for binary ops; add new testcase
  * Add new view ops testing scatter; add MPSDebugConfig.h config file for debugging purposes
  * Merge gatherViewTensor and scatterViewTensor into a generic function
  * Add scatter on demand in scatterViewOperation instead of caching it into a generic graph
  * Create separate graphs for scatter and gather;
  * Create scatter graph at scatter time
  Fixes #ISSUE_NUMBER
  Pull Request resolved: #79939
  Approved by: https://github.com/razarmehr
* MPS: Fix handling of 1D tensors in linear backward (#80759)
  Fixes #79784
  Pull Request resolved: #80759
  Approved by: https://github.com/ezyang
* [MPS] Move the View ops to a separate file and reduce the number of graphs created (#80491)
  This is dependent on the PR to go in first: #79939
  Remove the data_ptr from the View Graph key which reduces the number of graphs created significantly.
  Don't wait when copying from MPS to MPS tensors
  Pull Request resolved: #80491
  Approved by: https://github.com/malfet
* [MPS] Add softplus backward (#79873)
  Fixes #ISSUE_NUMBER
  Pull Request resolved: #79873
  Approved by: https://github.com/malfet
* [MPS] Add argmin (#80828)
  This PR 1. adds argmin 2. refactors `reduction_type` in `ReduceOps.mm` with enum.
  Co-authored by Kulin Seth <kulinseth@gmail.com>
  Pull Request resolved: #80828
  Approved by: https://github.com/malfet
* [MPS] Fix LSTM batch_first output transposed (#80597)
  The output of LSTM with `batch_first` should be transposed back to batch first format.
  Fixes #80306
  Pull Request resolved: #80597
  Approved by: https://github.com/kulinseth
* [MPS][BE] Introduce MPSUnaryCachedGraph (#81033)
  I.e. CachedGraph that has input and output tensors
  Also, add `MPSGraphCache::LookUpAs` template, which combines LookUp with static_cast to target type
  Pull Request resolved: #81033
  Approved by: https://github.com/kulinseth
* [MPS] Add test consistency from OpInfo based tests from PR 78504 (#79532)
  Pull Request resolved: #79532
  Approved by: https://github.com/albanD, https://github.com/malfet
* [MPS] Add huber loss (#80163)
  Fixes #ISSUE_NUMBER
  Pull Request resolved: #80163
  Approved by: https://github.com/kulinseth, https://github.com/malfet
* Remove two tests dependent on the MPS serialization checkin.
* Fix lint error (FLAKE8) F401
* Remove the serialization test from test_mps as its support is not there in 1.12.1.

Co-authored-by: Kulin Seth <kulinseth@gmail.com>
Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
Co-authored-by: Kulin Seth <kulin_seth@apple.com>
Co-authored-by: Abhishek Pathak <abhipathak97@gmail.com>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
Co-authored-by: qqaatw <qqaatw@gmail.com>
Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
The LSTM backward path is still broken (torch 1.12.1 and the nightly of 2022-09-11). As @PHRABAL noted, there is most probably a single […]
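To check the training path on a given device, one can compare forward outputs and gradients against a CPU reference with identical weights. Below is a minimal sketch that assumes the CPU result is correct. `compare_lstm_with_cpu` is a hypothetical helper, and the default sizes and tolerance are arbitrary choices. A fixed `sum()` loss is used so both backward passes receive the same upstream gradient.

```python
import torch
from torch import nn


def compare_lstm_with_cpu(device, batch=4, seq=10, input_size=8, hidden_size=16,
                          bidirectional=False, atol=1e-4):
    """Hypothetical helper: compare a batch_first LSTM on `device` with a CPU reference."""
    torch.manual_seed(0)
    cpu_lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=bidirectional)
    dev_lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=bidirectional)
    dev_lstm.load_state_dict(cpu_lstm.state_dict())  # identical weights
    dev_lstm.to(device)

    x_cpu = torch.randn(batch, seq, input_size, requires_grad=True)
    x_dev = x_cpu.detach().to(device).requires_grad_(True)

    out_cpu, _ = cpu_lstm(x_cpu)
    out_dev, _ = dev_lstm(x_dev)
    out_cpu.sum().backward()
    out_dev.sum().backward()

    shapes_match = out_dev.shape == out_cpu.shape
    results = {
        "output_shape_matches": shapes_match,
        # Only compare values when shapes agree; a transposed output would not broadcast.
        "output_close": shapes_match and torch.allclose(out_dev.cpu(), out_cpu, atol=atol),
        "input_grad_close": torch.allclose(x_dev.grad.cpu(), x_cpu.grad, atol=atol),
    }
    # Parameters are registered in the same order on both modules.
    for (name, p_cpu), (_, p_dev) in zip(cpu_lstm.named_parameters(), dev_lstm.named_parameters()):
        results[f"grad_{name}_close"] = torch.allclose(p_dev.grad.cpu(), p_cpu.grad, atol=atol)
    return results


print(compare_lstm_with_cpu("cpu"))
```

On Apple silicon one would call `compare_lstm_with_cpu("mps")`, and again with `bidirectional=True`, then look for any `False` entry. The `"cpu"` call above only confirms that the harness itself is consistent.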
This is still broken, and it's hugely problematic for my use case, since our library has bidirectional LSTMs everywhere.
This is using a nightly build from 2022-12-06.
The native implementation of LSTM has been fixed on macOS 13. On macOS 12, the multi-layer LSTM still has a numerical correctness issue that cannot be resolved on the OS side. Thus, on macOS 12 we fall back to iterating over LSTMCell for multi-layer LSTMs. This may have a performance impact, but it makes LSTM on macOS 12 fully usable.
Fixes: #90421
Issues related: #80306, #83144
Pull Request resolved: #90909
Approved by: https://github.com/albanD, https://github.com/kulinseth
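The fallback described above replaces the fused multi-layer kernel with a loop over `LSTMCell`. Below is a minimal, device-agnostic sketch of what such a loop computes. It is not the implementation from #90909, and `lstm_via_cells` is a hypothetical helper. It only handles unidirectional, `batch_first=True` LSTMs with no projection, no inter-layer dropout, and zero initial state. It also builds temporary cells from copied weights, so it illustrates the forward computation only: gradients would reach the temporary cells, not the original module.

```python
import torch
from torch import nn


def lstm_via_cells(lstm: nn.LSTM, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: evaluate a multi-layer nn.LSTM by iterating nn.LSTMCell.

    Sketch assumptions: batch_first=True, unidirectional, no projection
    (proj_size == 0), no inter-layer dropout, zero initial (h, c) states.
    Returns the last layer's output, shaped (batch, seq, hidden_size).
    """
    assert lstm.batch_first and not lstm.bidirectional
    assert lstm.proj_size == 0 and lstm.dropout == 0
    batch, seq_len, _ = x.shape
    layer_input = x
    for layer in range(lstm.num_layers):
        in_features = lstm.input_size if layer == 0 else lstm.hidden_size
        cell = nn.LSTMCell(in_features, lstm.hidden_size, bias=lstm.bias)
        cell = cell.to(device=x.device, dtype=x.dtype)
        # nn.LSTM and nn.LSTMCell share the same stacked (i, f, g, o) gate layout,
        # so the per-layer weights can be copied over unchanged.
        with torch.no_grad():
            cell.weight_ih.copy_(getattr(lstm, f"weight_ih_l{layer}"))
            cell.weight_hh.copy_(getattr(lstm, f"weight_hh_l{layer}"))
            if lstm.bias:
                cell.bias_ih.copy_(getattr(lstm, f"bias_ih_l{layer}"))
                cell.bias_hh.copy_(getattr(lstm, f"bias_hh_l{layer}"))
        h = x.new_zeros(batch, lstm.hidden_size)
        c = x.new_zeros(batch, lstm.hidden_size)
        steps = []
        for t in range(seq_len):
            # One time step for this layer; its h becomes the next layer's input at step t.
            h, c = cell(layer_input[:, t, :], (h, c))
            steps.append(h)
        layer_input = torch.stack(steps, dim=1)  # (batch, seq, hidden_size)
    return layer_input


torch.manual_seed(0)
lstm = nn.LSTM(8, 16, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 8)
reference, _ = lstm(x)
out = lstm_via_cells(lstm, x)
print(out.shape)  # torch.Size([3, 5, 16])
print(torch.allclose(out, reference, atol=1e-5))
```

The per-time-step Python loop is what makes this kind of fallback slower than a fused kernel, which is consistent with the performance caveat in the PR description.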
🐛 Describe the bug
The 1.13 nightly build, when sending an LSTM model to `device="mps"`, reverses the expected order of batch and seq in the output. Please see this discussion for code examples and further details:
https://discuss.pytorch.org/t/lstm-output-transposed/154820/2
Versions
Collecting environment information...
PyTorch version: 1.13.0.dev20220620
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 12.4 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-12.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==1.13.0.dev20220620
[conda] numpy 1.22.4 pypi_0 pypi
[conda] pytorch 1.13.0.dev20220620 py3.9_0 pytorch-nightly
cc @kulinseth @albanD