DataParallel is not compatible with pack_padded_sequence #2312

Open
ZiJianZhao opened this issue Aug 7, 2017 · 6 comments
Labels
awaiting response (deprecated) · module: data parallel · triaged

Comments

@ZiJianZhao

If the model uses pack_padded_sequence, then wrapping it in DataParallel fails with "ValueError: lengths array has incorrect size".

@ngimel
Collaborator

ngimel commented Aug 7, 2017

Can you please provide a minimal reproducer?

@soumith
Member

soumith commented Aug 30, 2017

@ZiJianZhao still waiting on a response.

@jgc128

jgc128 commented Nov 7, 2017

Hi,

I have the same error (see also issue #1591). The code below works on one GPU (CUDA_VISIBLE_DEVICES=0 python pack_padded_sequence_data_parallel.py), but fails with "ValueError: lengths array has incorrect size" on two GPUs (CUDA_VISIBLE_DEVICES=0,1 python pack_padded_sequence_data_parallel.py):

import numpy as np
import torch
from torch.autograd import Variable


class RNNDataParallel(torch.nn.Module):
    """Toy module that only packs its padded input."""

    def __init__(self):
        super(RNNDataParallel, self).__init__()

    def forward(self, inputs, lengths):
        # DataParallel scatters the `inputs` tensor across GPUs, but
        # `lengths` is a plain Python list, so every replica receives the
        # full list while holding only a slice of the batch.
        packed = torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=True)

        return packed


model = RNNDataParallel()
model = torch.nn.DataParallel(model)
model = model.cuda()

# Batch of two padded sequences, sorted by decreasing length.
inputs = Variable(torch.from_numpy(np.array([
    [1, 2, 3],
    [4, 5, 0],
])))
lengths = [3, 2]

packed = model(inputs, lengths)

print(packed)

My PyTorch version is 0.2.0+e02f7bf.

@ahmedmagdiosman

ahmedmagdiosman commented Nov 20, 2017

I encountered the same issue as @jgc128.

EDIT: I think the issue is that DataParallel does not slice CPU data such as the lengths list.

EDIT2: I "fixed" this by converting the lengths to a Variable(LongTensor.cuda()) before the forward pass and converting it back to a list before calling pack_padded_sequence. A sketch follows below.
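
A minimal sketch of that workaround, assuming two GPUs and the toy module from the reproducer above (an illustration, not the exact code from the thread):

import numpy as np
import torch
from torch.autograd import Variable


class RNNDataParallel(torch.nn.Module):
    def forward(self, inputs, lengths):
        # Each replica receives its own slice of `lengths` as a CUDA
        # Variable; convert it back to a Python list, which is what
        # pack_padded_sequence expects in 0.2.0.
        lengths = lengths.data.cpu().numpy().tolist()
        return torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=True)


model = torch.nn.DataParallel(RNNDataParallel()).cuda()

inputs = Variable(torch.from_numpy(np.array([
    [1, 2, 3],
    [4, 5, 0],
])).cuda())
# Wrapping lengths in a CUDA tensor makes DataParallel scatter it along
# dim 0 together with `inputs`, instead of replicating the full list.
lengths = Variable(torch.cuda.LongTensor([3, 2]))

packed = model(inputs, lengths)
print(packed)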

@jekbradbury
Contributor

pack_padded_sequence now supports the lengths being provided as a Tensor or Variable, I think?
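
If so, a sketch under that assumption (a newer PyTorch, where each replica can use its scattered slice of a lengths tensor directly; recent versions want the lengths as a 1-D CPU int64 tensor):

import torch


class RNNDataParallel(torch.nn.Module):
    def forward(self, inputs, lengths):
        # `lengths` arrives as this replica's slice of the batch; move it
        # to the CPU, as recent versions of pack_padded_sequence require.
        return torch.nn.utils.rnn.pack_padded_sequence(
            inputs, lengths.cpu(), batch_first=True)


model = torch.nn.DataParallel(RNNDataParallel()).cuda()

inputs = torch.tensor([[1, 2, 3],
                       [4, 5, 0]]).cuda()
lengths = torch.tensor([3, 2]).cuda()  # scattered along dim 0 with inputs

packed = model(inputs, lengths)
print(packed)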

@ahmedmagdiosman

@jekbradbury Not on my build (conda pytorch 0.2.0). Even so, the issue lies within DataParallel, which does not slice lengths. I fixed it with a simple hack (see the second edit in my previous comment).
