Pad PackedSequences to original batch length #1591

Open · ajfisch opened this issue May 19, 2017 · 7 comments
Labels: hackamonth · module: nestedtensor (NestedTensor tag, see issue #25032) · triaged (this issue has been looked at by a team member and triaged and prioritized into an appropriate module)

ajfisch (Contributor) commented May 19, 2017

The current flow for handling variable-length sequences in RNNs is:

packed_input = torch.nn.utils.rnn.pack_padded_sequence(input, lengths)
packed_output = rnn(packed_input)[0]
output, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(packed_output)

The output size of the Variable returned by pad_packed_sequence is determined by the max length in lengths: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/rnn.py#L106.

However, this recovers the original size of the input only if the longest sequence in the batch has no padding (i.e., max(lengths) equals the length dimension of the batched input). For normal, sensible batching this should be true.

But if a model is using, say, DataParallel, the batch might be split such that a chunk carries extra padding, and the output size from the RNN will be truncated relative to the input (which might break other things downstream).
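
For concreteness, here is a minimal sketch of the failure mode (the shapes are hypothetical; it assumes one replica's chunk happens to contain no full-length sequence):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# The full batch was padded to length 10, but this replica's chunk
# only contains sequences of length <= 7.
chunk = torch.zeros(10, 4, 5)             # (seq_len=10, batch=4, features=5)
lengths = [7, 6, 5, 3]                    # sorted, as pack_padded_sequence expects
packed = pack_padded_sequence(chunk, lengths)
output, output_lengths = pad_packed_sequence(packed)
print(output.size(0))                     # 7 -- truncated from the original 10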

To fix this potentially unexpected behavior, I propose two possible simple patches:

  1. A max_batch_size field, calculated from the original input, is added to the PackedSequence namedtuple (https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/rnn.py#L6) and used in place of the length inferred at https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/rnn.py#L106.

  2. An optional max_batch_size parameter is added to pad_packed_sequence, which would be used as an override.

I prefer 1), but I suppose 2) has the advantage of being fully backwards compatible (even though directly creating or meddling with PackedSequence tuples is discouraged, it is possible).
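
As a rough illustration of option 2, a user-level wrapper with the proposed override might look like the following (a minimal sketch; pad_packed_sequence_to and its max_batch_size argument are hypothetical names, and it assumes batch_first=False):

import torch
from torch.nn.utils.rnn import pad_packed_sequence

def pad_packed_sequence_to(packed, max_batch_size):
    # Unpack as usual; the time dimension comes out as max(lengths) of this
    # (possibly split) chunk, which may be shorter than the original padding.
    output, lengths = pad_packed_sequence(packed)
    missing = max_batch_size - output.size(0)
    if missing > 0:
        # Re-pad with zeros up to the original padded length.
        pad = torch.zeros(missing, *output.size()[1:],
                          dtype=output.dtype, device=output.device)
        output = torch.cat([output, pad], 0)
    return output, lengths

Option 1 would amount to the same computation, but with the length carried on the PackedSequence itself rather than passed in by the caller.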

cc @cpuhrsch

jekbradbury (Contributor) commented May 19, 2017

I agree with your preferred solution, even though I think my team's code is relying on this behavior at the moment (it also tripped us up).

yikang-li commented

Hi, I encountered the same problem when using DataParallel. Do you have any solution to it?

manasRK commented Jan 10, 2018

This issue is tripping up our pipeline as well. For now, we are concatenating dummy tensors and masking to take care of it.

However, is anyone working on this now?

soumith (Member) commented Jan 10, 2018

@manasRK no one's working on it, you can go for it.

yikang-li commented
Our solution is to use the original pad_packed_sequence and then add further padding up to the original length.

manasRK commented Feb 15, 2018

For people who stumble upon this thread in the future, this is how we are currently handling it.

First, the usual pack_padded_sequence and pad_packed_sequence for handling variable-length sequences:

seq_len, bsz, n_dims = feats.size()
packed_input = pack_padded_sequence(feats, lengths, batch_first=False)
packed_output, self.hidden = self.lstm(packed_input, self.hidden)
# lstm_out --> max(lengths) x bsz x hidden_dim (may be shorter than seq_len)
lstm_out, output_lengths = pad_packed_sequence(packed_output, batch_first=False)

Then comes the hack: the output size of the Variable returned by pad_packed_sequence is determined by the max value in output_lengths, not by the seq_len of the batch, so we pad it back up manually. (You may also have to hardcode MAXLEN in your sequence/loss masking procedures.)

if lstm_out.size(0) < seq_len:
    # Pad the output back to the original (pre-split) sequence length with zeros.
    # If lstm_out lives on the GPU, create the dummy tensor on the same device.
    dummy_tensor = autograd.Variable(torch.zeros(seq_len - lstm_out.size(0), bsz, self.hidden_dim))
    lstm_out = torch.cat([lstm_out, dummy_tensor], 0)

Our accuracy metrics have remained stable and predictions are in line with our expectations, so I think this hack works well.

cpuhrsch added the module: nestedtensor label and self-assigned this Nov 21, 2019
cpuhrsch (Contributor) commented

Since we have a solution, I'm using this issue to track work that could be replaced by NestedTensors, and am therefore assigning myself.
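
For later readers: the solution referred to here appears to be the total_length argument that pad_packed_sequence gained in later PyTorch releases, which pads the output back to a fixed length regardless of how the batch was split:

lstm_out, output_lengths = pad_packed_sequence(packed_output, batch_first=False, total_length=seq_len)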

cpuhrsch added the triaged label Nov 21, 2019