Adding support for CuDNN-based LSTM with projections #47725
Conversation
Thanks @BowenBao for helping with the ONNX questions! I have changed the asserts and also added an error message for the ONNX export, as LSTMs with projections are not supported there (enabling them will require modifying the onnxruntime code). I think only the backward-compatibility question remains open. @ngimel, @zou3519, please let me know which option you prefer here. If you are ok with breaking backward compatibility in this case, please let me know how to "add the operators to the allow-list".
Let's try breaking BC, option 2. To make the BC tests pass, you should add the functions you are breaking to the allow-list.
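For readers following along, the feature under discussion adds a `proj_size` argument to `nn.LSTM`. A minimal sketch of what it changes: the output and hidden state take the projected size, while the cell state keeps the full hidden size.

```python
import torch
import torch.nn as nn

# LSTM with projections: the hidden state is projected from 20 down to 5
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1, proj_size=5)
x = torch.randn(7, 3, 10)  # (seq_len, batch, input_size)
out, (h, c) = lstm(x)

print(out.shape)  # torch.Size([7, 3, 5])  -- projected size
print(h.shape)    # torch.Size([1, 3, 5])  -- projected size
print(c.shape)    # torch.Size([1, 3, 20]) -- full hidden size
```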
Importing to see what internal CI says.
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
The Bazel build failure looks unrelated and should probably go away with a rebase. The ROCm build should be fixed: modify the projection tests to expect a RuntimeError on ROCm, and adjust the tolerance for the failing fp16 test.
@ngimel, what would be the right way to test whether ROCm is used? I tried adding a `TEST_WITH_ROCM` check, but it seems that in some cases it's still using cudnn? Is it possible to check it at the model level somehow, e.g. the way I can distinguish between CPU and CUDA by checking the `.device` property?
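One runtime check that works at the build level (a sketch, not the mechanism the test suite ultimately used): `torch.version.hip` is a version string on ROCm builds and `None` on CUDA/CPU builds.

```python
import torch

# torch.version.hip is None on CUDA/CPU builds and a version string on ROCm builds
def is_rocm_build() -> bool:
    return torch.version.hip is not None

# a test could then expect projections to raise on ROCm, e.g.:
# if is_rocm_build():
#     with self.assertRaises(RuntimeError):
#         run_projection_model()   # hypothetical helper for illustration
print(is_rocm_build())
```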
@ngimel, @zou3519, I fixed the ROCm tests; now everything passes except for one test case that seems unrelated to my PR (at least it was not failing in the previous test runs). Is there anything else I need to do to complete the merge? Should I rebase onto master one more time to check whether the failing test is fixed?
The asan failure is unrelated. Let me try importing, thank you!
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Somehow `mypy torch/quantization` got broken in the past couple of days: https://gist.github.com/vkuzo/07af454246f0a68e6fa8929beeec7e0d . I didn't see any relevant PRs other than #47725, which doesn't seem related. The error doesn't seem real, as the arguments to `_cudnn_rnn_flatten_weight` seem correct. For now, ignoring the failure so we have a clean `mypy` run on `torch/quantization`.

Test Plan:
```
mypy torch/quantization
```

[ghstack-poisoned]
Summary: Pull Request resolved: #49549. Somehow `mypy torch/quantization` got broken in the past couple of days: https://gist.github.com/vkuzo/07af454246f0a68e6fa8929beeec7e0d . I didn't see any relevant PRs other than #47725, which doesn't seem related. The error doesn't seem real, as the arguments to `_cudnn_rnn_flatten_weight` seem correct. For now, ignoring the failure so we have a clean `mypy` run on `torch/quantization`.

Test Plan:
```
mypy torch/quantization
```

Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D25616972
fbshipit-source-id: 46c207fe1565ec949c0b1f57d6cd0c93f627e6bd
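The workaround described above, suppressing a single spurious error so the rest of the file stays type-checked, can be sketched as follows. The function and call are hypothetical stand-ins; the real change targets the `_cudnn_rnn_flatten_weight` call site.

```python
# sketch: a targeted per-line suppression keeps the rest of the file checked
def flatten(weights: list) -> int:
    return len(weights)

# mypy would flag the str argument here; the error is suppressed deliberately,
# and the ignore is scoped to one line and one error code
n = flatten("abc")  # type: ignore[arg-type]
print(n)  # 3
```

Scoping the ignore to a specific error code (`[arg-type]`) means mypy still reports any other problems on that line.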
Summary: Fixes pytorch#46213

I didn't yet update the documentation; will add those changes soon. A few other things that I didn't do, but want to clarify whether I should:

1. I didn't expose projections in the C++ API (torch/csrc/api/src/nn/modules/rnn.cpp). Let me know if this is desirable and I will add those changes.
2. I didn't expose projections in the `lstm_cell` function and the `_thnn_differentiable_lstm_cell_backward` function in aten/src/ATen/native/RNN.cpp. As far as I understand, they are not needed for nn.LSTM CPU execution. For `lstm_cell`, projections don't bring any real benefit, since if the cell is used separately, projections can easily be added in Python. For `_thnn_differentiable_lstm_cell_backward`, I'm actually not sure where exactly that function is used, so I also disabled projections there for now. Please let me know if I should change that.
3. I added a check that projections are not supported for quantized LSTMs to the quantized_lstm_<data/input> functions. But I didn't add any checks to the LSTMCell code. Since I disabled projections in the `lstm_cell` function, they should also not be available for quantized models through any API other than quantized_lstm_<data/input>. Please let me know if I'm not correct and I will add checks in other places.
4. Projections are not supported for CuDNN versions < 7.1.2. Should I add a check for the CuDNN version and disable projections in that case? If so, what would be the best way to do that?
5. Currently I added the projection weight as the last weight, so the layout is "w_ih, w_hh, b_ih, b_hh, w_hr". This breaks the assumption that biases come after weights, and thus I had to add additional if-s in various places. An alternative would be the "w_ih, w_hh, w_hr, b_ih, b_hh" layout, in which case the assumption would hold. But then I would need to split the loop in the get_parameters function in aten/src/ATen/native/cudnn/RNN.cpp, and in some cases I would still need to add an "undefined" tensor in the 3rd position, because we get all 5 weights from CuDNN most of the time. So I'm not sure which way is better. Let me know if you think I should change to the weights-then-biases layout.

Pull Request resolved: pytorch#47725
Reviewed By: zou3519
Differential Revision: D25449794
Pulled By: ngimel
fbshipit-source-id: fe6ce59e481d1f5fd861a8ff7fa13d1affcedb0c
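The "w_ih, w_hh, b_ih, b_hh, w_hr" layout from point 5 can be observed directly on the module: with `proj_size` set, the projection weight `weight_hr_l0` is registered after the biases, and its shape maps the full hidden size down to the projected size. A small sketch:

```python
import torch.nn as nn

# projection weight comes last in the per-layer parameter layout
lstm = nn.LSTM(input_size=4, hidden_size=8, proj_size=3)
names = [name for name, _ in lstm.named_parameters()]
print(names)
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_hr_l0']

print(lstm.weight_hr_l0.shape)  # torch.Size([3, 8]): (proj_size, hidden_size)
```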
Fixes #46213
I didn't yet update the documentation; will add those changes soon. A few other things that I didn't do, but want to clarify whether I should.