
Adding support for CuDNN-based LSTM with projections #47725

Closed · 61 commits

Conversation

@Kipok commented Nov 11, 2020

Fixes #46213

I haven't updated the documentation yet; I will add those changes soon. There are a few other things I didn't do, but I want to clarify whether I should:

  1. I didn't expose projections in the C++ API (torch/csrc/api/src/nn/modules/rnn.cpp). Let me know if this is desirable and I will add those changes.
  2. I didn't expose projections in the "lstm_cell" and "_thnn_differentiable_lstm_cell_backward" functions from aten/src/ATen/native/RNN.cpp. As far as I understand, they are not needed for nn.LSTM CPU execution. For lstm_cell, projections don't bring any real benefit: if the cell is used separately, the projection can easily be added in Python (see the first sketch after this list). For "_thnn_differentiable_lstm_cell_backward", I'm actually not sure where exactly that function is used, so I disabled projections there for now as well. Please let me know if I should change that.
  3. I added a check that projections are not supported for quantized LSTMs to the quantized_lstm_<data/input> functions, but I didn't add any checks to the LSTMCell code. Since I disabled projections in the "lstm_cell" function, it seems they should not be reachable for quantized models through any API other than quantized_lstm_<data/input>. Please let me know if I'm wrong and I will add checks in the other places.
  4. Projections are not supported for CuDNN versions < 7.1.2. Should I add a check for the CuDNN version and disable projections in that case? If so, what would be the best way to do that? (A possible guard is sketched below, after this list.)
  5. Currently I added the projection weight as the last weight, so the layout is "w_ih, w_hh, b_ih, b_hh, w_hr". This breaks the assumption that biases come after weights, so I had to add extra if statements in various places. An alternative would be the "w_ih, w_hh, w_hr, b_ih, b_hh" layout, under which the assumption would hold. But then I would need to split the loop in the get_parameters function in aten/src/ATen/native/cudnn/RNN.cpp, and in some cases I would still need to insert an "undefined" tensor in the 3rd position, because we usually get all 5 weights from CuDNN. So I'm not sure which way is better; let me know if you think I should switch to the weights-then-biases layout.
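For context on points 2 and 5, here is a minimal reference sketch (mine, not part of this PR) of one LSTM-with-projections step in plain Python, using the weight shapes implied by the "w_ih, w_hh, b_ih, b_hh, w_hr" layout; `lstmp_step` is a hypothetical helper name:

```python
import torch

def lstmp_step(x, hx, cx, w_ih, w_hh, b_ih, b_hh, w_hr):
    """One step of an LSTM with projections (hypothetical reference helper).

    Shapes, matching the "w_ih, w_hh, b_ih, b_hh, w_hr" layout above:
      x:    (batch, input_size)
      hx:   (batch, proj_size)        -- the recurrent state is the projected one
      cx:   (batch, hidden_size)
      w_ih: (4*hidden_size, input_size)
      w_hh: (4*hidden_size, proj_size)
      b_ih, b_hh: (4*hidden_size,)
      w_hr: (proj_size, hidden_size)  -- the projection weight
    """
    gates = x @ w_ih.t() + hx @ w_hh.t() + b_ih + b_hh
    i, f, g, o = gates.chunk(4, dim=1)  # PyTorch gate order: i, f, g, o
    cy = torch.sigmoid(f) * cx + torch.sigmoid(i) * torch.tanh(g)
    hy = (torch.sigmoid(o) * torch.tanh(cy)) @ w_hr.t()  # project the hidden state
    return hy, cy

# Illustrative shapes only:
batch, input_size, hidden_size, proj_size = 3, 10, 20, 5
hy, cy = lstmp_step(
    torch.randn(batch, input_size),
    torch.zeros(batch, proj_size),
    torch.zeros(batch, hidden_size),
    torch.randn(4 * hidden_size, input_size),
    torch.randn(4 * hidden_size, proj_size),
    torch.randn(4 * hidden_size),
    torch.randn(4 * hidden_size),
    torch.randn(proj_size, hidden_size),
)
```

Note that the recurrent weight w_hh consumes the projected state, which is why the projection can't simply be bolted onto an unmodified cell.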

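And a possible guard for point 4 (an assumption on my side, not something decided in this thread): torch.backends.cudnn.version() reports cuDNN versions as integers, e.g. 7102 for 7.1.2, so projections could be gated like this (`cudnn_supports_projections` is a made-up name):

```python
import torch

def cudnn_supports_projections() -> bool:
    # cuDNN encodes 7.1.2 as 7102 (major*1000 + minor*100 + patch);
    # version() returns None when cuDNN is unavailable.
    v = torch.backends.cudnn.version()
    return v is not None and v >= 7102
```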
@Kipok (Author) commented Dec 9, 2020

Thanks @BowenBao for helping with the ONNX questions! I have changed the asserts and also added an error message for the ONNX export, since LSTMs with projections are not supported there (enabling that support would require modifying the onnxruntime code).
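From the user's side, the new behavior could look roughly like this (a sketch under my own assumptions; the exact error type and message are not specified in this thread):

```python
import torch
import torch.nn as nn

model = nn.LSTM(input_size=10, hidden_size=20, proj_size=5)
x = torch.randn(7, 3, 10)

try:
    torch.onnx.export(model, (x,), "lstm_proj.onnx")
except RuntimeError as err:  # assuming the exporter surfaces a RuntimeError
    print("export rejected:", err)
```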

I think only the backward-compatibility question remains open now. @ngimel, @zou3519, please let me know which option you prefer here. If you are ok with breaking backward compatibility in this case, please let me know how to "add the operators to the allow-list".

@ngimel (Collaborator) commented Dec 9, 2020

Let's try breaking BC, option 2. To make the BC tests pass, add the functions you are breaking to the allow_list in test/backward_compatibility/check_backward_compatibility.py with some future date (say, a couple of weeks ahead); look at the examples of functions already there.
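For illustration, an entry in that allow_list is roughly an (operator name, expiry date) pair; the operator names below are my guesses at schemas this PR touches, not the PR's actual entries — check the file itself for the exact format:

```python
import datetime

# Sketch of allow_list entries in
# test/backward_compatibility/check_backward_compatibility.py.
# Names and dates here are illustrative.
allow_list = [
    ("aten::lstm", datetime.date(2020, 12, 23)),
    ("aten::_cudnn_rnn", datetime.date(2020, 12, 23)),
]
```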

@ngimel (Collaborator) commented Dec 10, 2020

Importing to see what internal CI says.

@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@Kipok (Author) commented Dec 10, 2020

@ngimel, @zou3519, both the ONNX and BC tests are passing now, but there are still two failures related to specific builds. Any suggestions on what needs to be done to fix those? Or is it ok that they fail?

@ngimel (Collaborator) commented Dec 10, 2020

The Bazel build failure looks unrelated and should probably go away with a rebase. The ROCm build should be fixed: modify the projection tests to expect a RuntimeError on ROCm, and adjust the tolerance for the failing fp16 test.
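A minimal sketch of such a guard (the test name, shapes, and exact failure point are my assumptions, not from this thread):

```python
import unittest

import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TEST_WITH_ROCM

class TestLSTMProjections(unittest.TestCase):
    def test_lstm_with_projections_cuda(self):
        lstm = nn.LSTM(input_size=10, hidden_size=20, proj_size=5).cuda()
        x = torch.randn(7, 3, 10, device="cuda")
        if TEST_WITH_ROCM:
            # Projections are unsupported on ROCm, so the forward pass
            # is expected to raise instead of running.
            with self.assertRaises(RuntimeError):
                lstm(x)
            return
        out, _ = lstm(x)
        # The output carries proj_size, not hidden_size.
        self.assertEqual(out.shape, (7, 3, 5))
```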

@Kipok (Author) commented Dec 11, 2020

@ngimel, what would be the right way to test whether ROCm is being used? I tried adding a "TEST_WITH_ROCM" if statement, but it seems that in some cases cuDNN is still used. Is it possible to check this at the model level somehow, e.g. the way I can distinguish between CPU and CUDA by checking the .device property?
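One way to probe this at runtime (my own suggestion, not something stated in this thread) is torch.version.hip, which is a version string on ROCm builds and None on plain CUDA builds:

```python
import torch

# On a ROCm build torch.version.hip is a version string; on a CUDA
# build it is None, so this distinguishes the two at runtime.
if torch.version.hip is not None:
    print("ROCm build")
elif torch.version.cuda is not None:
    print("CUDA build")
else:
    print("CPU-only build")
```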

@Kipok (Author) commented Dec 15, 2020

@ngimel, @zou3519, I fixed the ROCm tests, and now everything passes except for one test case that seems unrelated to my PR (at least it was not failing on the previous test runs). Is there anything else I need to do to complete the merge? Should I rebase onto master one more time to check whether the failing test is fixed?

@ngimel (Collaborator) commented Dec 15, 2020

The ASAN failure is unrelated. Let me try importing, thank you!

@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented

@ngimel merged this pull request in 1b6d18a.

vkuzo added a commit that referenced this pull request Dec 17, 2020

Summary:

Somehow `mypy torch/quantization` got broken in the past couple of days:
https://gist.github.com/vkuzo/07af454246f0a68e6fa8929beeec7e0d.
I didn't see any relevant PRs other than #47725, which doesn't seem
related. The error doesn't seem real, as the arguments to
`_cudnn_rnn_flatten_weight` seem correct. For now, ignoring the failure
so we have a clean `mypy` run on `torch/quantization`.

Test Plan:

```
mypy torch/quantization
```
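"Ignoring the failure" here presumably means a per-line suppression; a generic sketch of that mechanism, with a made-up call site since the commit's actual change isn't shown in this thread:

```python
# Hypothetical call site: if mypy misreads the arguments to a call that
# is actually fine, a targeted suppression keeps the rest of the run clean.
flat_weight = flatten_rnn_weights(weights)  # type: ignore[arg-type]
```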
vkuzo added commits that referenced this pull request on Dec 17, Dec 18, and Dec 21, 2020, re-landing the same change via ghstack (Differential Revision: D25616972).
facebook-github-bot pushed a commit that referenced this pull request Dec 22, 2020

Summary:
Pull Request resolved: #49549

Somehow `mypy torch/quantization` got broken in the past couple of days:
https://gist.github.com/vkuzo/07af454246f0a68e6fa8929beeec7e0d.
I didn't see any relevant PRs other than #47725, which doesn't seem
related. The error doesn't seem real, as the arguments to
`_cudnn_rnn_flatten_weight` seem correct. For now, ignoring the failure
so we have a clean `mypy` run on `torch/quantization`.

Test Plan:
```
mypy torch/quantization
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25616972

fbshipit-source-id: 46c207fe1565ec949c0b1f57d6cd0c93f627e6bd
hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 6, 2021

Summary:
Fixes pytorch#46213 (the PR description above)

Pull Request resolved: pytorch#47725

Reviewed By: zou3519

Differential Revision: D25449794

Pulled By: ngimel

fbshipit-source-id: fe6ce59e481d1f5fd861a8ff7fa13d1affcedb0c
hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 6, 2021

Summary:
Pull Request resolved: pytorch#49549 (the mypy fix above)

Reviewed By: jerryzh168

Differential Revision: D25616972

fbshipit-source-id: 46c207fe1565ec949c0b1f57d6cd0c93f627e6bd
Labels: cla signed, Merged, open source, triaged

Linked issue: Integrating CuDNN API for LSTMs with projections

10 participants