[MPS] Fix LSTM backward and forward pass #95137
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95137
Note: Links to docs will display an error until the docs builds have been completed. ❌ 4 failures as of commit a7bfd09. NEW FAILURES: the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
By the way, is there more detailed documentation of the MPSGraph methods? Apple's documentation has no descriptions. @kulinseth, what docs are you using? Or is this one of the challenges of MPS development? 😅
Also, this will always fail, as lines 8914 to 8936 in f89ae0a show:
test/test_mps.py (Outdated)

```diff
-    @unittest.skipIf(True, "Backward of lstm returns wrong result")
-    def test_lstm_2(self, device="mps", dtype=torch.float32):
+    def test_lstm_backward_one_layer(self, device="mps", dtype=torch.float32):
+        layers = 1
```
This is for a future test with several layers. It passes now with one layer, but fails with two layers (a sketch of what the enabled test checks follows).
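For context, a minimal sketch of the kind of backward parity check this test performs. This is not the PR's actual test body; the function name, shapes, and tolerances here are illustrative assumptions.

```python
import torch

def test_lstm_backward_one_layer_sketch(device="mps", dtype=torch.float32):
    # identical single-layer LSTMs on CPU and MPS
    lstm_cpu = torch.nn.LSTM(4, 4, num_layers=1).to(dtype)
    lstm_mps = torch.nn.LSTM(4, 4, num_layers=1).to(dtype).to(device)
    lstm_mps.load_state_dict({k: v.to(device) for k, v in lstm_cpu.state_dict().items()})

    x_cpu = torch.randn(5, 3, 4, dtype=dtype, requires_grad=True)  # (seq, batch, feature)
    x_mps = x_cpu.detach().to(device).requires_grad_()

    out_cpu, _ = lstm_cpu(x_cpu)
    out_mps, _ = lstm_mps(x_mps)
    out_cpu.sum().backward()
    out_mps.sum().backward()

    # gradients w.r.t. the input and every parameter must match the CPU reference
    assert torch.allclose(x_cpu.grad, x_mps.grad.cpu(), atol=1e-5)
    for (name, p_cpu), (_, p_mps) in zip(lstm_cpu.named_parameters(), lstm_mps.named_parameters()):
        assert torch.allclose(p_cpu.grad, p_mps.grad.cpu(), atol=1e-5), name
```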
```cpp
Tensor output_out = at::empty_like(input);
Tensor grad_state_out = at::empty_like(hx[0]);
Tensor grad_cell_state_out = at::empty_like(hx[1]);
// ...
std::vector<Tensor> grad_hx = {grad_state_out, grad_cell_state_out};
```
Basically, the output binding was completely broken before: gradients w.r.t. the input were zero in the best case, and garbage values appeared frequently.
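The symptom, sketched in user-facing torch terms (a hypothetical repro, not code from the PR):

```python
import torch

lstm = torch.nn.LSTM(10, 8).to("mps")
x = torch.randn(5, 3, 10, device="mps", requires_grad=True)
out, _ = lstm(x)
out.sum().backward()
# before the fix, x.grad came back all-zero at best and garbage at worst
print(x.grad.abs().max())
```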
```objc
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex:num_layers - i - 1], grad_rec_weights);
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex:num_layers - i - 1], grad_weights);
gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex:num_layers - i - 1], grad_bias);
```
Notice the indices! The elements are stored in reverse order because they are pushed while iterating over the layers backwards.
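A toy illustration of that index mapping, in plain Python just to show the arithmetic:

```python
num_layers = 3
grads = []
for i in range(num_layers - 1, -1, -1):  # the backward pass walks layers in reverse
    grads.append(f"grad_layer_{i}")      # so layer i lands at position num_layers - i - 1

for i in range(num_layers):
    assert grads[num_layers - i - 1] == f"grad_layer_{i}"
```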
test/test_mps.py (Outdated)

```python
hx = torch.zeros(2, 3, 4, device="cpu")
cx = torch.zeros(2, 3, 4, device="cpu")
```
This test fails when the states are initialized randomly, though I sometimes cannot reproduce the failure locally. Added the change to the PR; let's see how it runs on CI.
Fails on CI too
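The change, roughly: random instead of zero initial states, so a wrong state-gradient path can no longer hide behind zeros (assuming `torch.randn`; the PR may use a different initializer):

```python
# before: zero initial states masked incorrect gradient flow through (hx, cx),
# since a wrongly-zero gradient matches the reference by accident
hx = torch.zeros(2, 3, 4, device="cpu")
cx = torch.zeros(2, 3, 4, device="cpu")

# after: random states exercise the same path with nonzero values
hx = torch.randn(2, 3, 4, device="cpu")
cx = torch.randn(2, 3, 4, device="cpu")
```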
```cpp
weights.push_back(grad_bias);
weights.push_back(grad_bias);
```
Weird: why is `grad_bias` pushed twice?
Ok, I see why it is needed ;) PyTorch's LSTM keeps two bias vectors per layer (`bias_ih` and `bias_hh`) that are summed into the same gate pre-activation, so their gradients are identical and the single `grad_bias` is returned for both.
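A quick CPU check of that identity (a sketch; the parameter names are the standard torch ones):

```python
import torch

lstm = torch.nn.LSTM(4, 4, num_layers=1)
x = torch.randn(5, 2, 4)
out, _ = lstm(x)
out.sum().backward()
# both biases enter the gates as b_ih + b_hh, so their gradients coincide
assert torch.allclose(lstm.bias_ih_l0.grad, lstm.bias_hh_l0.grad)
```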
```objc
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
```
There's an error in the state gradient bindings.
```objc
MPSGraphTensor* gradState = cachedGraph->gradState_;
MPSGraphTensor* gradCellState = cachedGraph->gradCellState_;

Placeholder gradStatePlaceholder = Placeholder(gradState, grad_state_out);
Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out);
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
```
This is the crucial part of the state gradient calculation; it was missing before.
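What this enables, in user-facing terms (hypothetical repro, not the PR's test): gradients with respect to the initial hidden and cell states.

```python
import torch

lstm = torch.nn.LSTM(4, 4).to("mps")
x = torch.randn(5, 3, 4, device="mps")
hx = torch.randn(1, 3, 4, device="mps", requires_grad=True)
cx = torch.randn(1, 3, 4, device="mps", requires_grad=True)
out, _ = lstm(x, (hx, cx))
out.sum().backward()
# without the result bindings above, these gradients were never read back from the graph
print(hx.grad.norm(), cx.grad.norm())
```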
Quite a mistake: all layers are using the same input.
@kulinseth, at this point the one-layer LSTM works as expected and passes the asserts against the CPU implementation. But the multi-layer version is inconsistent, as it incorrectly works with only one input: the backward call must have the outputs of all layers to calculate the gradients correctly (see the sketch below).
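Why one input is not enough, sketched with an explicitly unrolled stack (standard torch modules; a stacked `num_layers=2` LSTM is equivalent to this chain):

```python
import torch

layer1 = torch.nn.LSTM(8, 8)
layer2 = torch.nn.LSTM(8, 8)
x = torch.randn(5, 2, 8, requires_grad=True)

y1, _ = layer1(x)    # layer 2 consumes layer 1's output sequence...
y2, _ = layer2(y1)
y2.sum().backward()  # ...so backward through layer 2 needs y1, not the top-level input x
```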
Excited to announce that the LSTM is now fully consistent with the CPU implementation! Gradients of all parameters and outputs are asserted.
@albanD, can you please take a look at these changes? The MPS-side changes are fine.
@pytorchbot merge -f "MPS tests are green."
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Approval needed from one of the following:
@pytorchbot merge -f "MPS tests are green."
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes pytorch#91694
Fixes pytorch#92615

Several transpositions were missing in the backward graph in the case of `batch_first=True`. pytorch#91694 is not reproduced with `batch_first=False`.

After fixing the transpose issue, I finally thought that I could now use LSTM freely in my project. And then I got horrific results during training, seemingly related to pytorch#92615. After that, I decided to fix LSTM's backward step completely. I collected all my findings in this thread — seems like I succeeded.

Funny enough, the backward tests were completely disabled before and were not passing:

```python
@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
```

UPD: the forward pass of the multi-layer version was also wrong, due to incorrect `initState, initCell` slices. The tests were passing because the states were initialized with zeros. *Accidentally* fixed this too.

Pull Request resolved: pytorch#95137
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer
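For reference, the `batch_first` layout that the missing transposes concern (a usage sketch, not code from the PR):

```python
import torch

lstm = torch.nn.LSTM(10, 8, batch_first=True)
x = torch.randn(4, 7, 10)  # (batch, seq, feature) rather than the default (seq, batch, feature)
out, _ = lstm(x)           # out: (4, 7, 8)
# the MPS graph presumably works seq-first internally, so with batch_first=True the
# inputs, outputs, and their gradients must be transposed at the op boundary
```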
@soulitzer @albanD this test is introducing backward-incompatible changes. Is this expected/safe? The failure (logs):
* [MPS] Fix LSTM backward and forward pass (#95137)
* Update the allowlist for lstm_mps_backward
* More update to the BC allowlist

Co-authored-by: alexdremov <dremov.me@gmail.com>
Co-authored-by: albanD <desmaison.alban@gmail.com>
@ZainRizvi Yup, this is expected. LSTM seemed to be silently incorrect on MPS previously, so this should be considered a bug fix.
@soulitzer, @albanD, I added a fix in the release branch for this. I can cherry-pick it to master.
Fixes #91694 Fixes #92615 Several transpositions were missing for backward graph in case of `batch_first=True`. The #91694 is not reproduced with `batch_first=False`. After fixing transpose issue, I finally thought that now I can use LSTM freely in my project. And then I got horrific results on train. Seems related to #92615. After that I decided to fix LSTM's backward step completely. I collected all my findings in this thread — seems like I succeeded Funny enough, backward tests were completely disabled before and were not passing: ```python @unittest.skipIf(True, "Backward of lstm returns wrong result") def test_lstm_2(self, device="mps", dtype=torch.float32): ``` UPD: forward pass of multi-layer version also was wrong due to the incorrect `initState, initCell` slices. Tests were passing because states were inited with zeros. *Accidentally* fixed this too Pull Request resolved: pytorch/pytorch#95137 Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer