-
Notifications
You must be signed in to change notification settings - Fork 21.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding support for CuDNN-based LSTM with projections #47725
Closed
Closed
Changes from 44 commits
Commits
Show all changes
61 commits
Select commit
Hold shift + click to select a range
2aa911e
Expose proj_dim parameter
a6c987c
Fix int -> int_64t issue
2a5fdce
Start exposing proj_size throughout the code base
ee74637
Exposed through get_weight_buf and flatten_weights
c72c9ce
Expose proj_size through _cudnn_rnn
cd3acc2
Update get_parameters to work with projections
54a9849
Remove redundant proj_size, fix get_expected_data_ptrs
24e47c2
Fix try_get_weight_buf
dab693d
Fix _cudnn_rnn function
4d07a2a
Fix _cudnn_rnn_backward functions
5481d77
Add correct hx creation on python side
d0dc6db
Fix incorrect projection layers initialization
956ed26
Correct weight initialization on python side
4d0feae
Fix output size issue
134ad1d
Fix multi-layer projections issue
9557035
Expose proj_size in setstate and extra_repr
2da10ee
Fix AutocastRNN to accept models with projections
b63facc
Merge branch 'master' into cudnn_projections
b142a0a
Add test TODOs, add check for non-cudnn code
44d0c67
Fix incorrect hidden states init for LSTM
da32e77
Fix error for RNN/GRU of accessing undefined cx
2a2c3f9
Fix no-bias projections lstm for fp32
de5519d
Fix no-bias projection fp16 case
06fadd2
Add check for rnn/gru, add initial unit tests
09d0fe1
Add proj_size to test_variable_sequence
ae84ba6
Add projections to rnn_weight_norm test
8b0544d
Add projections to cudnn_weight_format test
db24db7
Add projections to cudnn_weight_tying test
515341c
Add projections to rnn_args_check test
337b0e7
Add projections to rnn_check_device test
ccb5054
Expose cudnn with projections on CPU
527f62c
Add cpu_vs_gpu projections test
4542031
Add weight norm test
28986a8
Remove TODOs
3bbd503
Code clean up
03e386e
Revert empty changes
336732f
Revert miopen style change
8842c13
Disable projections for quantized LSTMs
46f3172
Merge branch 'master' into cudnn_projections
2aedcf6
Fix linting errors
f830048
Remove .cuda() call from proj hidden_state test
dfae3ff
Add projections documentation to nn.LSTM
8113922
Merge branch 'master' into cudnn_projections
b723ec5
Remove cuda placement from proj initial_hidden_state test
e2206b1
Merge branch 'master' into cudnn_projections
54cca58
Address PR comments
107bb28
Expose projections in c++ API
9038cfc
Add c++ integration tests for projections
821f89b
Add check output size test
a593687
Add more projection tests to c++ api
47465c9
Merge branch 'master' into cudnn_projections
be413cf
Add correct type hints
af359cf
Change number of ops in caffe2 onnx tests
ff01240
Add unimplemented call for LSTMs with projections in onnx
8f95f8e
Add onnx test to check projections not supported
bb87f91
Merge branch 'master' into cudnn_projections
4d7d6ca
Add _cudnn_rnn functions to allow_list for bc
6bb8bff
Adjust tests to work on rocm
6834b59
Merge branch 'master' into cudnn_projections
42f1563
Add more precise check for runtime error in tests
f696eda
Merge branch 'master' into cudnn_projections
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a reason you are not making it pure virtual, like all other functions? If so, please explain it in the comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason is that it's easier to define it here, so that this default behavior can be directly re-used in all function that don't support projections (e.g. quantized cells). I added explanation to this in the code comments as well