
[MPS] Fix LSTM backward and forward pass #95137

Closed · wants to merge 17 commits

Conversation

@alexdremov (Contributor) commented Feb 19, 2023:

Fixes #91694
Fixes #92615

Several transpositions were missing in the backward graph when `batch_first=True`; #91694 does not reproduce with `batch_first=False`.
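For context, a minimal repro along these lines shows the mismatch (a sketch; the shapes and seed here are arbitrary, not taken from the original issue):

```python
import torch
import torch.nn as nn

def lstm_out_and_grad(device):
    torch.manual_seed(0)  # identical weight init on both devices
    rnn = nn.LSTM(8, 16, num_layers=1, batch_first=True).to(device)
    inp = torch.ones(4, 5, 8, device=device, requires_grad=True)  # (batch, seq, feature)
    out, _ = rnn(inp)
    out.sum().backward()
    return out.detach().cpu(), inp.grad.cpu()

cpu_out, cpu_grad = lstm_out_and_grad("cpu")
mps_out, mps_grad = lstm_out_and_grad("mps")
torch.testing.assert_close(cpu_out, mps_out)
torch.testing.assert_close(cpu_grad, mps_grad)  # failed before the missing transposes were added
```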

After fixing the transpose issue, I thought I could finally use LSTM freely in my project. Then I got horrific results during training, which seems related to #92615.

After that, I decided to fix LSTM's backward step completely. I collected all my findings in this thread, and it seems I succeeded.

Funny enough, backward tests were completely disabled before and were not passing:

```python
@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
```

UPD: the forward pass of the multi-layer version was also wrong, due to incorrect `initState`/`initCell` slices. Tests were passing because states were initialized with zeros. *Accidentally* fixed this too.
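A check along these lines catches the slicing bug, since non-zero initial states are what expose it (a sketch; shapes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn_cpu = nn.LSTM(8, 16, num_layers=2)
rnn_mps = nn.LSTM(8, 16, num_layers=2).to("mps")
rnn_mps.load_state_dict(rnn_cpu.state_dict())  # identical weights on both devices

inp = torch.randn(5, 3, 8)
hx = torch.randn(2, 3, 16)  # non-zero states expose the bad initState/initCell slices
cx = torch.randn(2, 3, 16)

out_cpu, _ = rnn_cpu(inp, (hx, cx))
out_mps, _ = rnn_mps(inp.to("mps"), (hx.to("mps"), cx.to("mps")))
torch.testing.assert_close(out_cpu, out_mps.cpu())
```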

@pytorch-bot bot commented Feb 19, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95137

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 Failures

As of commit a7bfd09:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the ciflow/mps (Run MPS tests, subset of trunk) and release notes: mps (Release notes category) labels on Feb 19, 2023
@alexdremov (Contributor, Author) commented Feb 19, 2023:

By the way, is there more detailed documentation of the MPSGraph methods? Apple's documentation has no descriptions. @kulinseth, what docs are you using? Or is this one of the challenges of MPS development? 😅

@alexdremov (Contributor, Author):

Also, this will always fail, as `inp` is randomly generated separately for each device:

pytorch/test/test_mps.py, lines 8914 to 8936 in f89ae0a:

```python
@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
    def get_results(device):
        rnn = nn.LSTM(1, 4, 1, device=device)
        inp = torch.randn(2, 3, 1, device=device, requires_grad=True)
        hx = torch.zeros(1, 3, 4, device=device)
        cx = torch.zeros(1, 3, 4, device=device)
        output, _ = rnn(inp, (hx, cx))
        output.sum().backward()
        weight_grad = rnn.weight_ih_l0.grad.clone()
        input_grad = inp.grad.clone()
        return output, weight_grad, input_grad

    cpu_output, cpu_weight_grad, cpu_input_grad = get_results("cpu")
    mps_output, mps_weight_grad, mps_input_grad = get_results("mps")
    self.assertEqual(cpu_output, mps_output)
    self.assertEqual(cpu_input_grad, mps_input_grad)
    self.assertEqual(cpu_weight_grad, mps_weight_grad)
```
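One way to make the comparison deterministic is to seed the RNG and generate the input once, reusing it on both devices. A sketch, not necessarily the exact fix that was merged:

```python
import torch
import torch.nn as nn

def get_results(device, inp_cpu):
    torch.manual_seed(0)  # identical weight init on both devices
    rnn = nn.LSTM(1, 4, 1).to(device)
    inp = inp_cpu.detach().to(device).requires_grad_()
    hx = torch.zeros(1, 3, 4, device=device)
    cx = torch.zeros(1, 3, 4, device=device)
    output, _ = rnn(inp, (hx, cx))
    output.sum().backward()
    return output, rnn.weight_ih_l0.grad.clone(), inp.grad.clone()

inp_cpu = torch.randn(2, 3, 1)  # generated once, shared by both runs
cpu_out, cpu_wgrad, cpu_igrad = get_results("cpu", inp_cpu)
mps_out, mps_wgrad, mps_igrad = get_results("mps", inp_cpu)
```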

@alexdremov changed the title from "[MPS] LSTM batch_first=True fix" to "[MPS] Fix LSTM backward pass" on Feb 19, 2023
test/test_mps.py (outdated):

```diff
-    @unittest.skipIf(True, "Backward of lstm returns wrong result")
-    def test_lstm_2(self, device="mps", dtype=torch.float32):
+    def test_lstm_backward_one_layer(self, device="mps", dtype=torch.float32):
+        layers = 1
```
@alexdremov (Contributor, Author):

This prepares for a future test with several layers. It passes with one layer now, but fails with two layers.

Comment on lines +507 to +512
```cpp
Tensor output_out = at::empty_like(input);
Tensor grad_state_out = at::empty_like(hx[0]);
Tensor grad_cell_state_out = at::empty_like(hx[1]);

std::vector<Tensor> grad_hx = {grad_state_out, grad_cell_state_out};
```
@alexdremov (Contributor, Author):

Basically, the output binding was completely broken before. Gradients w.r.t. the input were zero in the best case; garbage values appeared frequently.

Comment on lines 549 to 551
```objc
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex:num_layers - i - 1], grad_rec_weights);
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex:num_layers - i - 1], grad_weights);
gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex:num_layers - i - 1], grad_bias);
```
@alexdremov (Contributor, Author):

Notice the indices! Elements are stored in reverse order because they are pushed while iterating over the layers backwards.
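The same bookkeeping in miniature (plain Python, only to illustrate the index arithmetic):

```python
num_layers = 3
grad_weights = []
for layer in reversed(range(num_layers)):  # backward walks layers last-to-first,
    grad_weights.append(f"grad_w{layer}")  # so the last layer's gradient is pushed first

# retrieving layer i's gradient therefore needs the flipped index:
for i in range(num_layers):
    assert grad_weights[num_layers - i - 1] == f"grad_w{i}"
```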

test/test_mps.py (outdated), lines 8882 to 8883:

```python
hx = torch.zeros(2, 3, 4, device="cpu")
cx = torch.zeros(2, 3, 4, device="cpu")
```
@alexdremov (Contributor, Author):

This test fails when the states are initialized randomly, but I sometimes cannot reproduce the failure. Added the change to the PR; let's see how it runs on CI.

@alexdremov (Contributor, Author):

Fails on CI too

Comment on lines 536 to 537
```cpp
weights.push_back(grad_bias);
weights.push_back(grad_bias);
```
@alexdremov (Contributor, Author):

weird

@alexdremov (Contributor, Author):

Ok, I see why it is needed ;)
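For the record: `bias_ih` and `bias_hh` enter every LSTM gate additively, so their gradients are identical, and pushing the same `grad_bias` twice is correct. A quick CPU check (a sketch) confirms it:

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(4, 8, 1)
out, _ = rnn(torch.randn(5, 2, 4))
out.sum().backward()
# b_ih and b_hh are simply summed inside each gate, so d/db_ih == d/db_hh
torch.testing.assert_close(rnn.bias_ih_l0.grad, rnn.bias_hh_l0.grad)
```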

Comment on lines -557 to -558
```objc
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
```
@alexdremov (Contributor, Author):

There was an error in the state-gradient bindings.

Comment on lines 509 to 515
```objc
MPSGraphTensor* gradState = cachedGraph->gradState_;
MPSGraphTensor* gradCellState = cachedGraph->gradCellState_;

Placeholder gradStatePlaceholder = Placeholder(gradState, grad_state_out);
Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out);
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
```
@alexdremov (Contributor, Author):

This is the crucial part of the state-gradient calculation; it was missing before.
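With these bindings in place, gradients flowing into the initial states can be checked directly; a sketch with arbitrary shapes:

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(1, 4, 1).to("mps")
hx = torch.zeros(1, 3, 4, device="mps", requires_grad=True)
cx = torch.zeros(1, 3, 4, device="mps", requires_grad=True)
out, _ = rnn(torch.randn(2, 3, 1, device="mps"), (hx, cx))
out.sum().backward()
assert hx.grad is not None and cx.grad is not None  # populated only once the bindings exist
```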

@alexdremov (Contributor, Author):

Quite a mistake:

```objc
outputs = [mpsGraph LSTMGradientsWithSourceTensor: inputTensor
```

All layers are using the same input.

@alexdremov (Contributor, Author):

@kulinseth, at this point the one-layer LSTM works as expected and passes the assertions against the CPU comparison. But the multi-layer version is still inconsistent: as noted above, every layer's `LSTMGradientsWithSourceTensor:` call incorrectly uses the same `inputTensor`, while the backward call must have information about the outputs of all layers to calculate gradients correctly.
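The dependency is easy to see at the Python level with a manually stacked two-layer LSTM (a sketch; the layer sizes are arbitrary): layer 1's input is layer 0's output, so layer 1's weight gradients cannot be computed from the network input alone.

```python
import torch
import torch.nn as nn

layer0 = nn.LSTM(4, 8, 1)
layer1 = nn.LSTM(8, 8, 1)

x = torch.randn(5, 2, 4)
y0, _ = layer0(x)    # layer1's "source" tensor is y0, not x
y1, _ = layer1(y0)
y1.sum().backward()  # layer1's weight gradients depend on y0, so backward must keep it
```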

@alexdremov (Contributor, Author) commented Feb 20, 2023:

Excited to announce that the LSTM is now fully consistent with the CPU implementation! Gradients of all parameters and outputs are asserted against the CPU results. 🚀🚀🚀

@kulinseth (Collaborator):

@albanD, can you please take a look at these changes? The MPS-side changes are fine.

@kulinseth (Collaborator):

@pytorchbot merge -f "MPS tests are green."

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

@pytorchmergebot (Collaborator):

Merge failed

Reason: Approval needed from one of the following:
kunalb, rohan-varma, ziky90, vtlam, PratsBhatt, ...

Raised by workflow job.

Failing merge rule: Core Maintainers

@razarmehr (Collaborator):

@pytorchbot merge -f "MPS tests are green."

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Feb 23, 2023
Pull Request resolved: pytorch#95137
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer
This was referenced Feb 23, 2023
@ZainRizvi (Contributor):

@soulitzer @albanD this PR is introducing backward-incompatible changes to the operator library. Is this expected/safe?

The failure (logs):

```
The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not.

Broken ops: [
	aten::lstm_mps_backward.out(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> ()
	aten::lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
	aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
	aten::_lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
]
```

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Feb 24, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 25, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 25, 2023
atalman pushed a commit that referenced this pull request Feb 25, 2023
* [MPS] Fix LSTM backward and forward pass (#95137)
* Update the allowlist for lstm_mps_backward
* More update to the BC allowlist

Co-authored-by: alexdremov <dremov.me@gmail.com>
Co-authored-by: albanD <desmaison.alban@gmail.com>
@soulitzer (Contributor):

@ZainRizvi Yup, this is expected. LSTM seemed to be silently incorrect on MPS previously, so this should be considered a bug fix.

@kulinseth (Collaborator):

@soulitzer , @albanD , added a fix in the release branch for this. I can cherry-pick it to master.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
pruthvistony added a commit to ROCm/pytorch that referenced this pull request May 2, 2023
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request May 3, 2023
jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 2024
Labels: ciflow/mps (Run MPS tests, subset of trunk), Merged, open source, release notes: mps (Release notes category)
8 participants