Add support for variable length sequences in RNNs #873
Conversation
torch/nn/utils/rnn.py
Outdated
    if batch_first:
        output = output.transpose(0, 1)
    return output
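To illustrate what that transpose buys, a minimal sketch against the public torch.nn.utils.rnn API (sizes are made up): with batch_first the unpacked output comes back batch-major instead of time-major.

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    # Two sequences of true lengths 3 and 2, padded to length 3.
    padded = torch.randn(3, 2, 5)                  # (seq_len, batch, features)
    packed = pack_padded_sequence(padded, [3, 2])  # lengths sorted descending

    out, lengths = pad_packed_sequence(packed)
    assert out.shape == (3, 2, 5)                  # time-major by default

    out_bf, _ = pad_packed_sequence(packed, batch_first=True)
    assert out_bf.shape == (2, 3, 5)               # transposed, as in the diff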
Thanks for picking this up for me @apaszke; this looks great!
I'm going to let the PyOpenNMT people know about this, because I think switching to it could simplify some of the PyOpenNMT code, e.g. https://github.com/pytorch/examples/blob/master/OpenNMT/onmt/Translator.py#L56-L79
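For reference, the kind of simplification meant here, sketched against today's torch.nn.utils.rnn names (the encoder sizes are invented): instead of slicing per-sequence outputs by hand, sort by length, pack once, run the RNN, and unpack.

    import torch
    from torch import nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    rnn = nn.LSTM(input_size=8, hidden_size=16)
    batch = torch.randn(5, 3, 8)       # (seq_len, batch, input_size), padded
    lengths = [5, 4, 2]                # true lengths, sorted descending

    packed = pack_padded_sequence(batch, lengths)
    packed_out, (h_n, c_n) = rnn(packed)           # padding never enters the RNN
    output, _ = pad_packed_sequence(packed_out)    # (5, 3, 16), padding restored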
        return descriptor


    def descriptor_sequence(tensor, batch_sizes):
         self.dropout_seed = torch.IntTensor(1).random_()[0]
         self.dropout_state = dropout_state

     def forward_extended(self, input, weight, hx):
-        assert(cudnn.is_acceptable(input))
+        assert cudnn.is_acceptable(input)
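The parenthesized form is worth avoiding because assert is a statement, not a function; the failure mode (my illustration, not from the thread) is that parentheses around multiple values build a truthy tuple:

    x = 0
    assert (x == 0)                   # redundant parens, but harmless

    # Bug: this builds the tuple (False, "x should be 1"), which is truthy,
    # so the assertion can never fire (recent CPython emits a SyntaxWarning).
    assert (x == 1, "x should be 1")  # always passes!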
Nice. @ngimel might want to review the cudnn parts.
One thought: there's an implicit (?) invariant that differentiable arguments to autograd functions are of type Variable; there are various places that check for this. Is the use of PackedSequence instead of a Variable going to break anything in e.g. DataParallel?
torch/nn/modules/rnn.py
Outdated
@@ -234,6 +236,8 @@ class LSTM(RNNBase):

     Inputs: input, (h_0, c_0)
+        - **input** (seq_len, batch, input_size): tensor containing the features of the input sequence.
+          The input can also be a packed variable length sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
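To make the docstring addition concrete, a small sketch (sizes invented, current API) of feeding a packed sequence straight into an LSTM together with explicit initial states:

    import torch
    from torch import nn
    from torch.nn.utils.rnn import pack_padded_sequence

    lstm = nn.LSTM(input_size=4, hidden_size=6, num_layers=1)
    padded = torch.randn(7, 2, 4)                 # (seq_len, batch, input_size)
    packed = pack_padded_sequence(padded, [7, 3])

    h_0 = torch.zeros(1, 2, 6)                    # (num_layers, batch, hidden)
    c_0 = torch.zeros(1, 2, 6)
    packed_out, (h_n, c_n) = lstm(packed, (h_0, c_0))
    # With packed input, h_n holds each sequence's state at its true last
    # step, not at the padded end.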
torch/nn/utils/rnn.py
Outdated
@@ -6,19 +6,43 @@

 PackedSequence = namedtuple('PackedSequence', ['data', 'batch_sizes'])


-def pack_padded_sequence(tensor, lengths, batch_first=False):
+def pack_padded_sequence(input, lengths, batch_first=False):
     """Packs a Variable containing padded sequences of variable length.
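For intuition about the (data, batch_sizes) layout named in the namedtuple, a tiny worked example (values chosen for illustration):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    # Two padded sequences: [1, 2, 3] and [4, 5, <pad>].
    padded = torch.tensor([[1., 4.],
                           [2., 5.],
                           [3., 0.]])        # (seq_len=3, batch=2)
    packed = pack_padded_sequence(padded, [3, 2])

    print(packed.data)          # tensor([1., 4., 2., 5., 3.]) -- time-major, no padding
    print(packed.batch_sizes)   # tensor([2, 2, 1]) -- active sequences per step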
@adamlerer, good point re: DataParallel. But a packed tensor can't be scattered, so the way to do multi-GPU with this is to have a module that accepts a padded tensor Variable as input, and wrap that module in DataParallel.
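A sketch of that pattern (the module, sizes, and the enforce_sorted flag are mine, written against today's API rather than the Variable-era one): packing happens inside the wrapped module, so DataParallel only ever scatters padded tensors along the batch dimension.

    import torch
    from torch import nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    class PackedRNN(nn.Module):
        """Takes a padded batch plus lengths; packs internally so that
        nn.DataParallel can scatter the padded input across devices."""
        def __init__(self):
            super().__init__()
            self.rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

        def forward(self, padded, lengths):
            packed = pack_padded_sequence(padded, lengths.cpu(),
                                          batch_first=True, enforce_sorted=False)
            out, _ = self.rnn(packed)
            out, _ = pad_packed_sequence(out, batch_first=True)
            return out

    # batch_first keeps the batch on dim 0, the dimension DataParallel scatters.
    model = nn.DataParallel(PackedRNN())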
The PR is still lacking docs and pep8 fixes, so it's not ready to merge yet, but I wanted to get it out today so it can be reviewed. I'll address any comments tomorrow.
Fixes #789.
cc @jekbradbury