LSTM Output Transposed w/MPS on 1.13 nightly build #80306

Closed
PHRABAL opened this issue Jun 26, 2022 · 7 comments
Labels
module: mps (Related to Apple Metal Performance Shaders framework)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@PHRABAL

PHRABAL commented Jun 26, 2022

🐛 Describe the bug

In the 1.13 nightly build, sending an LSTM model to device="mps" reverses the expected order of the batch and sequence dimensions in the output.

Please see this discussion for code examples and further details:

https://discuss.pytorch.org/t/lstm-output-transposed/154820/2

Versions

Collecting environment information...
PyTorch version: 1.13.0.dev20220620
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.4 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-12.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==1.13.0.dev20220620
[conda] numpy 1.22.4 pypi_0 pypi
[conda] pytorch 1.13.0.dev20220620 py3.9_0 pytorch-nightly

cc @kulinseth @albanD

@mruberry added the module: mps (Related to Apple Metal Performance Shaders framework) and triage review labels Jun 27, 2022
@gscalise

gscalise commented Jun 27, 2022

Here is a minimal case confirming this issue, verified on PyTorch 1.13.0.dev20220627 and Python 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]:

import torch
from torch.nn import LSTM

assert torch.backends.mps.is_available()

tensor_shapes = {}

for device_type in ("mps", "cpu"):
    device = torch.device(device_type)
    lstm = LSTM(64, 128, 1, batch_first=True)
    lstm.to(device)
    input_tensor = torch.randn((2, 4, 64)).to(device)
    output, _ = lstm(input_tensor)
    tensor_shapes[device_type] = output.shape
    
print(tensor_shapes) # prints {'mps': torch.Size([4, 2, 128]), 'cpu': torch.Size([2, 4, 128])}
                     #                            ^^^^^                           ^^^^^

Interestingly, this doesn't happen when batch_first is False:

import torch
from torch.nn import LSTM

assert torch.backends.mps.is_available()

tensor_shapes = {}

for device_type in ("mps", "cpu"):
    device = torch.device(device_type)
    lstm = LSTM(64, 128, 1)
    lstm.to(device)
    input_tensor = torch.randn((2, 4, 64)).to(device)
    output, _ = lstm(input_tensor)
    tensor_shapes[device_type] = output.shape
    
print(tensor_shapes) # prints {'mps': torch.Size([2, 4, 128]), 'cpu': torch.Size([2, 4, 128])}
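
Both results can be checked against the layout documented for `nn.LSTM`: whichever layout `batch_first` selects, the output keeps the input's two leading dimensions, and only the feature dimension becomes `D * H_out` (`D = 2` if bidirectional, `H_out = proj_size` if set, else `hidden_size`). A small pure-Python helper that encodes this rule can serve as the reference in cross-device shape checks like the ones above. `expected_lstm_output_shape` is a hypothetical name, not a PyTorch function.

```python
from typing import Tuple


def expected_lstm_output_shape(
    input_shape: Tuple[int, int, int],
    hidden_size: int,
    bidirectional: bool = False,
    proj_size: int = 0,
) -> Tuple[int, int, int]:
    """Hypothetical helper (not part of PyTorch): expected shape of the
    `output` tensor of torch.nn.LSTM for a batched 3-D input.

    batch_first=False: input (L, N, H_in) -> output (L, N, D * H_out)
    batch_first=True:  input (N, L, H_in) -> output (N, L, D * H_out)
    Either way the two leading dimensions pass through unchanged,
    so batch_first itself is not needed to compute the result.
    """
    num_directions = 2 if bidirectional else 1
    h_out = proj_size if proj_size > 0 else hidden_size
    lead0, lead1, _input_features = input_shape
    return (lead0, lead1, num_directions * h_out)


# Shapes from the first example: LSTM(64, 128, 1, batch_first=True), input (2, 4, 64).
expected = expected_lstm_output_shape((2, 4, 64), hidden_size=128)
print(expected)                 # (2, 4, 128)
print((2, 4, 128) == expected)  # True: matches the CPU result
print((4, 2, 128) == expected)  # False: the MPS result has batch and sequence swapped
```

Since `torch.Size` is a tuple subclass, `output.shape == expected_lstm_output_shape(...)` also works directly on a real tensor's shape.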

@gscalise

Also, this doesn't happen with GRU:

import torch
from torch.nn import GRU

assert torch.backends.mps.is_available()

tensor_shapes = {}

for device_type in ("mps", "cpu"):
    device = torch.device(device_type)
    gru = GRU(64, 128, 1, batch_first=True)
    gru.to(device)
    input_tensor = torch.randn((2,4,64)).to(device)
    output, _ = gru(input_tensor)
    tensor_shapes[device_type] = output.shape
    
print(tensor_shapes) # prints {'mps': torch.Size([2, 4, 128]), 'cpu': torch.Size([2, 4, 128])}

@mruberry added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and removed the triage review label Jun 27, 2022
@qqaatw
Collaborator

qqaatw commented Jun 30, 2022

This should be fixed by the linked PR (#80597).

facebook-github-bot pushed a commit that referenced this issue Jul 8, 2022
Summary:
The output of LSTM with `batch_first` should be transposed back to batch first format.

Fixes #80306

Pull Request resolved: #80597
Approved by: https://github.com/kulinseth

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b0b24b4285dae33cf8699a3b25b37675a655ba4b

Reviewed By: mehtanirav

Differential Revision: D37687576

Pulled By: mehtanirav

fbshipit-source-id: d0837b4aa3e1f5f33cc614dd5fe14b532fb96304
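
The fix in this commit amounts to transposing the output back to batch-first layout. On a build without it, a user-level workaround in the same spirit is to keep the default `batch_first=False` (which gscalise's second example shows returning the expected shape on MPS) and do the two transposes yourself. This is a minimal sketch that assumes only the public `nn.LSTM` API; `BatchFirstLSTM` is a hypothetical name, not a PyTorch class. It only addresses the output layout of the forward pass, not the backward-pass problems discussed later in this thread.

```python
import torch
from torch import nn


class BatchFirstLSTM(nn.Module):
    """Hypothetical wrapper (not a PyTorch class): accepts and returns
    (batch, seq, feature) tensors but runs nn.LSTM with batch_first=False."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Force the sequence-first code path. Pass batch_first by keyword, if at all;
        # passing it positionally together with this override raises a TypeError.
        kwargs["batch_first"] = False
        self.lstm = nn.LSTM(*args, **kwargs)

    def forward(self, x, hx=None):
        # (N, L, H_in) -> (L, N, H_in) for the sequence-first LSTM.
        out, state = self.lstm(x.transpose(0, 1), hx)
        # h_n and c_n are (D * num_layers, N, H) regardless of batch_first,
        # so only the output is transposed back to (N, L, D * H_out).
        return out.transpose(0, 1), state
```

Usage mirrors `nn.LSTM(..., batch_first=True)`, for example `BatchFirstLSTM(64, 128, 1).to("mps")`. Because the wrapped module keeps the standard parameter names (`weight_ih_l0`, and so on), its `lstm.state_dict()` can be loaded into an ordinary `batch_first=True` LSTM and vice versa.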
kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Jul 9, 2022
The output of LSTM with `batch_first` should be transposed back to batch first format.

Fixes pytorch#80306

Pull Request resolved: pytorch#80597
Approved by: https://github.com/kulinseth
@PHRABAL
Author

PHRABAL commented Jul 9, 2022

Hi,

The committed changes did not fix the issue.

The model's output order is now correct for batch_first=True, but running backpropagation through the model results in an error.

This simple example demonstrates the problem. It runs fine with batch_first=False, but with batch_first=True it fails with the error shown below.

import torch
import torch.nn as nn


device = torch.device("mps")
print(device)

batch_size = 64
seq_len = 10
input_size = 128
hidden_size = 32

class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
    
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, batch):
        x, _ = self.lstm(batch)
        x = self.fc(x)
        return x              

model = Model().to(device)
loss_func = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.0001)

x = torch.randn(batch_size, seq_len, input_size).to(device)
y = torch.rand(batch_size, seq_len, 1).to(device)  # BCEWithLogitsLoss targets should lie in [0, 1]

print(x.shape)
print(y.shape)

for epoch in range(1, 10):
    out = model(x)
    loss = loss_func(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss.item())

print(out.shape)

The error is:

loc("total derivative last state"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/b6051351-c030-11ec-96e9-3e7866fcf3a1/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":219:0)): error: input types 'tensor<1x10x32xf32>' and 'tensor<1x64x32xf32>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
Abort trap: 6

It looks like the batch (64) and sequence (10) dimensions are still not being swapped during the backward pass.
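
Here is one plausible reading of the shapes in that error, assuming that "last state" refers to the final hidden state `h_n`. For `nn.LSTM`, `h_n` has shape `(D * num_layers, N, H_out)`, and unlike the output, it is not affected by `batch_first`. With the sizes in this example that gives `(1, 64, 32)`, which matches the second tensor. The first tensor, `(1, 10, 32)`, has the sequence length where the batch size should be. The arithmetic as a pure-Python check; `expected_lstm_hn_shape` is a hypothetical helper:

```python
def expected_lstm_hn_shape(batch_size, hidden_size, num_layers=1, bidirectional=False, proj_size=0):
    """Hypothetical helper (not part of PyTorch): shape of h_n from torch.nn.LSTM
    for a batched input. Always (D * num_layers, N, H_out); batch_first does not
    change it. (c_n uses hidden_size instead of H_out in the last dimension.)"""
    num_directions = 2 if bidirectional else 1
    h_out = proj_size if proj_size > 0 else hidden_size
    return (num_directions * num_layers, batch_size, h_out)


batch_size, seq_len, hidden_size = 64, 10, 32  # sizes from the reproduction above
print(expected_lstm_hn_shape(batch_size, hidden_size))  # (1, 64, 32)
print((1, seq_len, hidden_size))                        # (1, 10, 32): sequence length in the batch slot
```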

@qqaatw
Collaborator

qqaatw commented Jul 9, 2022

@PHRABAL Hi, the LSTM backward pass on MPS already had issues before this fix:

@unittest.skipIf(True, "Backward of lstm returns wrong result")

So currently only the forward pass is supported and tested.
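
Until backward is supported, one way to check whether a given build's LSTM backward on MPS agrees with CPU is to run an identical forward and backward pass on both devices and compare the parameter gradients. This is a minimal sketch that uses only public PyTorch APIs. `lstm_grad_mismatch` is a hypothetical helper, and the sizes mirror PHRABAL's example. On builds affected by the error above, the MPS call may abort the process rather than return.

```python
import torch
from torch import nn


def lstm_grad_mismatch(device: str, batch_first: bool = True, seed: int = 0) -> float:
    """Hypothetical helper: largest absolute difference between CPU and `device`
    parameter gradients after one forward/backward pass through identical LSTMs."""
    torch.manual_seed(seed)
    cpu_lstm = nn.LSTM(input_size=128, hidden_size=32, num_layers=1, batch_first=batch_first)
    dev_lstm = nn.LSTM(input_size=128, hidden_size=32, num_layers=1, batch_first=batch_first)
    dev_lstm.load_state_dict(cpu_lstm.state_dict())  # identical weights on both devices
    dev_lstm.to(device)

    # batch=64, seq=10, features=128, in whichever layout batch_first selects
    shape = (64, 10, 128) if batch_first else (10, 64, 128)
    x = torch.randn(shape)

    grads = []
    for lstm, inp in ((cpu_lstm, x), (dev_lstm, x.to(device))):
        out, _ = lstm(inp)
        out.sum().backward()
        grads.append([p.grad.cpu() for p in lstm.parameters()])

    return max((a - b).abs().max().item() for a, b in zip(*grads))
```

On Apple silicon, compare `lstm_grad_mismatch("mps")` and `lstm_grad_mismatch("mps", batch_first=False)` against a tolerance of your choosing. When run against `"cpu"`, the result should be essentially zero, which makes that a quick sanity check of the helper itself.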

atalman pushed a commit to atalman/pytorch that referenced this issue Jul 22, 2022
The output of LSTM with `batch_first` should be transposed back to batch first format.

Fixes pytorch#80306

Pull Request resolved: pytorch#80597
Approved by: https://github.com/kulinseth
atalman added a commit that referenced this issue Jul 25, 2022
* MPS: Fixes (#78930)

Cast integer to float in UnaryOps
Add tensor dtype in key generation
Enable FP16 scalars and use placeholder for alpha tensor in add/sum ops

Pull Request resolved: #78930
Approved by: https://github.com/albanD

* MPS: Binary cast fix by proper type promotion and remove spurious copy warning (#79185)

Fixes #78019, #78020
Fixes #79185
Pull Request resolved: #79185
Approved by: https://github.com/albanD, https://github.com/razarmehr

* MPS: add exponential op (#79188)

Add exponential distribution

Pull Request resolved: #79188
Approved by: https://github.com/razarmehr, https://github.com/albanD

* [MPS] Delete unused vars from OperationUtils.mm

Pull Request resolved: #79514

Approved by: https://github.com/kulinseth, https://github.com/albanD

* [MPS] Fix getDefaultGenerator and copy_kernel_mps

Returning a reference to stack memory is undefined behavior, since the reference dangles once the function returns

Pull Request resolved: #79515

Approved by: https://github.com/albanD

* [MPS][BE] Do not use `new/delete[]` in `chainViewOperation`

`std::array` will do just fine

Pull Request resolved: #79516

Approved by: https://github.com/albanD

* [MPS] Support stride of stride

Fixes #79181

Pull Request resolved: #79521

Approved by: https://github.com/kulinseth

* MPS: TopK raise an error if K>16 (#79677)

* Error out in TopK when k>16.
* Add a test case too.

Fixes #78915

Pull Request resolved: #79677
Approved by: https://github.com/albanD

* [MPS]: Add fix for squeezed input axes handling in BCE loss (#79676)

Fixes #79527

Pull Request resolved: #79676
Approved by: https://github.com/razarmehr, https://github.com/albanD

* MPS: Add amax and amin Ops with tests (#79682)

* Add amax and amin with tests

Pull Request resolved: #79682
Approved by: https://github.com/albanD

* [MPS] Fix torch.uint8 support (#80049)

`ScalarType.Byte` should be cast to `MPSDataTypeUInt8`.
Also add support for `torch.int8`, and test those conversions in `TestMPS.test_to`.

Fixes #80006

Pull Request resolved: #80049
Approved by: https://github.com/albanD

* [MPS] Fix binary ops between int32 tensor with int64 scalar (#80220)

For some reason, tensor *op* scalar does not follow the normal binary promotion rules,
so cast the output tensor to the expected type if needed.
It seems that one should have cast the input tensors to the expected output tensor type, but that does not really work for boolean binary ops, so...
Add output tensor type/shape to cached graph key
Extend `TestMPS.test_add_scalars` to test for this regression

Fixes #79835

Pull Request resolved: #80220
Approved by: https://github.com/albanD

* [MPS] Add equal operator (#80195)

It is, in essence, a composite of `eq`->`all`->`item`.
`native/mps/operators/Equal.cpp` is an almost verbatim copy of `native/cuda/Equal.cpp`

Fix codegen by generating MPSFunctions headers

Pull Request resolved: #80195
Approved by: https://github.com/albanD

* [MPS] add `aten::normal.Tensor_float` `aten::normal.float_Tensor` `aten::normal.Tensor_Tensor` (#80297)

Pull Request resolved: #80297
Approved by: https://github.com/albanD, https://github.com/kulinseth

* [MPS] Add flip (#80214)

Pull Request resolved: #80214
Approved by: https://github.com/DenisVieriu97, https://github.com/albanD

* [MPS] Add logical ops (#80216)

This PR adds `logical_not`, `logical_and`, `logical_or`, `logical_xor`.
Pull Request resolved: #80216
Approved by: https://github.com/albanD, https://github.com/kulinseth

* [MPS] Add glu (#79866)

Adds mps op for `aten::glu.out`.

Pull Request resolved: #79866
Approved by: https://github.com/kulinseth, https://github.com/albanD

* [MPS] Fix std/var cache issue (#80502)

Use `getTensorsStringKey`, which adds the tensor shape to the cache key, to prevent cache lookup issues when the shape of the input tensor changes.

Fixes #80499

Pull Request resolved: #80502
Approved by: https://github.com/malfet, https://github.com/kulinseth

* Add scatter support for view operations (#79939)

* Add scatter support for view operations; #78074, #78886, #79672
* Update test_slicing_replace_column to properly test different sizes
* Handle in-place changes for binary ops; add new testcase
* Add new view ops testing scatter; add MPSDebugConfig.h config file for debugging purposes
* Merge gatherViewTensor and scatterViewTensor into a generic function
* Add scatter on demand in scatterViewOperation instead of caching it into a generic graph
* Create separate graphs for scatter and gather;
* Create scatter graph at scatter time

Pull Request resolved: #79939
Approved by: https://github.com/razarmehr

* MPS: Fix handling of 1D tensors in linear backward (#80759)

Fixes #79784

Pull Request resolved: #80759
Approved by: https://github.com/ezyang

* [MPS] Move the View ops to a separate file and reduce the number of graphs created (#80491)

This depends on #79939, which needs to go in first.

Remove the data_ptr from the View Graph key which reduces the number of
graphs created significantly.

Don't wait when copying from MPS to MPS tensors

Pull Request resolved: #80491
Approved by: https://github.com/malfet

* [MPS] Add softplus backward (#79873)

Pull Request resolved: #79873
Approved by: https://github.com/malfet

* [MPS] Add argmin (#80828)

This PR

1. adds argmin
2. refactors `reduction_type` in `ReduceOps.mm` with enum.

Co-authored by Kulin Seth <kulinseth@gmail.com>
Pull Request resolved: #80828
Approved by: https://github.com/malfet

* [MPS] Fix LSTM batch_first output transposed (#80597)

The output of LSTM with `batch_first` should be transposed back to batch first format.

Fixes #80306

Pull Request resolved: #80597
Approved by: https://github.com/kulinseth

* [MPS][BE] Introduce MPSUnaryCachedGraph (#81033)

I.e., a CachedGraph that has input and output tensors.
Also, add `MPSGraphCache::LookUpAs` template, which combines LookUp with
static_cast to target type

Pull Request resolved: #81033
Approved by: https://github.com/kulinseth

* [MPS] Add test consistency from OpInfo based tests from PR 78504 (#79532)

Pull Request resolved: #79532
Approved by: https://github.com/albanD, https://github.com/malfet

* [MPS] Add huber loss (#80163)

Pull Request resolved: #80163
Approved by: https://github.com/kulinseth, https://github.com/malfet

* Remove two tests dependent on the MPS serialization checkin.

* Fix lint error (FLAKE8) F401

* Remove the serialization test from test_mps, since serialization is not supported in 1.12.1.

Co-authored-by: Kulin Seth <kulinseth@gmail.com>
Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
Co-authored-by: Kulin Seth <kulin_seth@apple.com>
Co-authored-by: Abhishek Pathak <abhipathak97@gmail.com>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
Co-authored-by: qqaatw <qqaatw@gmail.com>
Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
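
The `[MPS] Fix binary ops between int32 tensor with int64 scalar` item in the commit above refers to "the normal binary promotion rules". For reference, here is how those rules work on CPU, which the MPS results are expected to match. A Python scalar or 0-dim tensor of the same category does not widen a dimensioned tensor. A float scalar moves an integer tensor to the default floating dtype (assumed here to be the unchanged `torch.float32`). Comparison ops return `torch.bool`. The sketch uses only public APIs, including `torch.result_type`:

```python
import torch

a = torch.ones(3, dtype=torch.int32)

# Same category (integer): a Python int or a 0-dim int64 tensor does not widen `a`.
print((a + 5).dtype)                # torch.int32
print((a + torch.tensor(5)).dtype)  # torch.int32
print(torch.result_type(a, 5))      # torch.int32

# Higher category (floating): the result takes the default float dtype.
print((a + 2.5).dtype)              # torch.float32

# Comparison ("boolean binary") ops always produce bool.
print((a == 5).dtype)               # torch.bool
```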
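
The `[MPS] Add equal operator` item describes `equal` as a composite of `eq` -> `all` -> `item`. Below is a Python-level sketch of that composite. It includes one detail the composite presumably needs: `torch.equal` returns `False` for tensors of different shapes, whereas `torch.eq` broadcasts, so the sizes must be compared first. `equal_via_eq_all_item` is a hypothetical name used for illustration; the actual operator is implemented in native code.

```python
import torch


def equal_via_eq_all_item(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical illustration (not the real operator): torch.equal as eq -> all -> item."""
    if a.shape != b.shape:
        # torch.eq would broadcast (e.g. shape (3,) vs (1,)), but equal must return False.
        return False
    return bool(torch.eq(a, b).all().item())


x = torch.tensor([1.0, 2.0, 3.0])
print(equal_via_eq_all_item(x, x.clone()), torch.equal(x, x.clone()))  # True True
print(equal_via_eq_all_item(x, torch.tensor([1.0, 2.0, 4.0])))         # False
print(bool(torch.eq(torch.ones(3), torch.ones(1)).all()))              # True (broadcast)
print(equal_via_eq_all_item(torch.ones(3), torch.ones(1)))             # False
```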
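
The `[MPS] Fix std/var cache issue` item deals with compiled-graph caching. If the cache key does not include the input shape, a graph built for one shape is silently reused for another. Below is a toy model of that failure in plain Python. `graph_key` and `GraphCache` are hypothetical stand-ins, not the MPS backend's `getTensorsStringKey` or its graph cache.

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def graph_key(op: str, shapes: Sequence[Shape], dtypes: Sequence[str], include_shape: bool = True) -> str:
    """Hypothetical stand-in: string cache key from the op name and each input's dtype (and optionally shape)."""
    parts = [op]
    for shape, dtype in zip(shapes, dtypes):
        parts.append(dtype)
        if include_shape:
            parts.append("x".join(str(dim) for dim in shape))
    return ":".join(parts)


class GraphCache:
    """Hypothetical toy cache; the 'graph' stored for a key is just the shape it was built for."""

    def __init__(self) -> None:
        self._graphs: Dict[str, Shape] = {}

    def lookup_or_build(self, key: str, shape: Shape) -> Shape:
        # Build (store) on a miss; on a hit, return whatever was cached under this key.
        return self._graphs.setdefault(key, shape)


def shapes_served(include_shape: bool) -> List[Shape]:
    """Run 'std' on two differently shaped inputs and report which graph each call received."""
    cache = GraphCache()
    served = []
    for shape in [(4, 8), (2, 3)]:
        key = graph_key("std", [shape], ["float32"], include_shape=include_shape)
        served.append(cache.lookup_or_build(key, shape))
    return served


print(shapes_served(include_shape=False))  # [(4, 8), (4, 8)]: stale graph reused for (2, 3)
print(shapes_served(include_shape=True))   # [(4, 8), (2, 3)]
```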
@domschl
Contributor

domschl commented Sep 11, 2022

The LSTM backward path is still broken (torch 1.12.1 and the 2022-09-11 nightly).

As @PHRABAL noted, most probably a single permute() call that swaps the temporal and LSTM hidden dimensions is missing. Adding it would most probably also fix the numerical issues.

@AngledLuffa

This is still broken, and it's hugely problematic for my use case, since our library has bidirectional LSTMs everywhere:

import torch

torch.manual_seed(1234)
lstm = torch.nn.LSTM(5, 5, num_layers=2, bidirectional=True, batch_first=True, dropout=0.5)
lstm.eval()  # disable dropout so the CPU and MPS outputs are directly comparable
inp = torch.randn(1, 4, 5)
print(lstm(inp)[0])
print(lstm.to("mps")(inp.to("mps"))[0])

This uses a nightly build from 2022-12-06.

pytorchmergebot pushed a commit that referenced this issue Mar 10, 2023
The native implementation of LSTM has been fixed on macOS 13.

On macOS 12, the multi-layer LSTM still has a numerical correctness issue that cannot be resolved on the OS side.

Thus, on macOS 12 the multi-layer LSTM falls back to iterating over LSTMCell. This might have a performance impact, but it makes LSTM on macOS 12 fully usable.

Fixes: #90421
Issues related: #80306, #83144

Pull Request resolved: #90909
Approved by: https://github.com/albanD, https://github.com/kulinseth
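
The fallback described in this commit runs the multi-layer LSTM as an explicit loop over `LSTMCell` steps. Below is a minimal Python sketch of that idea; it is not the actual MPS implementation. For each layer, it copies `weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}` and `bias_hh_l{k}` into an `nn.LSTMCell` (both modules use the same gate layout) and steps through time, feeding each layer's outputs to the next. The sketch makes several assumptions:

- `batch_first=True`
- a unidirectional LSTM with biases and no `proj_size`
- zero initial states
- no inter-layer dropout (for example, the module is in eval mode)

`lstm_via_cells` is a hypothetical name. Because the weights are copied, the sketch is only meant for checking forward outputs.

```python
import torch
from torch import nn


def lstm_via_cells(lstm: nn.LSTM, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: forward pass of a batch-first, unidirectional nn.LSTM
    computed by iterating nn.LSTMCell. Returns the top layer's output (N, L, hidden_size)."""
    assert lstm.batch_first and not lstm.bidirectional and lstm.bias and lstm.proj_size == 0
    batch_size, seq_len, _ = x.shape
    layer_input = x
    for k in range(lstm.num_layers):
        cell = nn.LSTMCell(layer_input.shape[-1], lstm.hidden_size).to(device=x.device, dtype=x.dtype)
        with torch.no_grad():  # copy layer k's parameters into the cell
            cell.weight_ih.copy_(getattr(lstm, f"weight_ih_l{k}"))
            cell.weight_hh.copy_(getattr(lstm, f"weight_hh_l{k}"))
            cell.bias_ih.copy_(getattr(lstm, f"bias_ih_l{k}"))
            cell.bias_hh.copy_(getattr(lstm, f"bias_hh_l{k}"))
        h = x.new_zeros(batch_size, lstm.hidden_size)
        c = x.new_zeros(batch_size, lstm.hidden_size)
        steps = []
        for t in range(seq_len):
            h, c = cell(layer_input[:, t], (h, c))  # one time step for the whole batch
            steps.append(h)
        layer_input = torch.stack(steps, dim=1)  # (N, L, hidden_size) feeds the next layer
    return layer_input
```

Comparing `lstm_via_cells(lstm, x)` with `lstm(x)[0]` for a two-layer LSTM in eval mode on CPU should give matching outputs up to floating-point tolerance. The explicit Python loop over time steps also illustrates why such a fallback can cost performance.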