-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add option to automatically handle unsorted variable-length sequences…
… in RNNs (#15225) Summary: Fixes #3584. Motivation: manually sorting sequences, packing them, and then unsorting them is something a lot of users have complained about doing, especially when we can offer library support for them. Overview: we internally sort sequences before packing them and store a list of `unsorted_indices` that represent how to unsort the sequences inside PackedSequence. The packing helper functions return PackedSequence with the `permutation` field and the unpacking helper functions use it to unsort. To implement this, the following changes were made: - PackedSequence now keeps `sorted_indices` and `unsorted_indices`. These two can be thought of as permutations and are inverses of each other. `sorted_indices` is how the sequences were sorted; `unsorted_indices` is how to unsort the sequences. - Added an `enforce_sorted` argument to pack_sequence and pack_padded_sequence that maintains the legacy behavior of error-ing out on unsorted-sequences. When `enforce_sorted=True`, these functions maintain their ONNX exportability. - pack_sequence(sequences, enforce_sorted) takes in unsorted sequences. - pack_padded_sequence can take in a padded tensor that represents padded, unsorted sequences. - pad_packed_sequence unsorts the PackedSequence such that it is still the inverse operation of packed_padded_sequence. - RNNs apply `sort_indices` to their input hidden state and apply `unsort_indices` to their output hidden state. This is to ensure that the hidden state batches correspond to the user's ordering of input sequences. NOT BC-Breaking - The default for pack_sequence and pack_padded_sequence is `enforce_sorted=True` to avoid breaking ONNX export. To use the new functionality, pass in `enforce_sorted=False` Testing Plan - Modified TestNN.test_pack_sequence, TestNN.test_packed_padded_sequence, and TestNN.test_variable_sequence (RNN test) to check the behavior of unsorted sequences, sorted sequences, and sorted sequences with enforce_sorted=True - test/test_jit.py has a test to see if RNNs are exportable with enforce_sorted=True cc colesbury Pull Request resolved: #15225 Reviewed By: soumith Differential Revision: D13507138 Pulled By: zou3519 fbshipit-source-id: b871dccd6abefffca81bc4e3efef1873faa242ef
- Loading branch information
1 parent
52699f0
commit 3353064
Showing
3 changed files
with
306 additions
and
121 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.