
DataParallel with Torch 1.5 #40457

Open
mnxu7979 opened this issue Jun 23, 2020 · 10 comments
Labels
high priority module: data parallel module: regression It used to work, and now it doesn't triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mnxu7979

mnxu7979 commented Jun 23, 2020

🐛 Bug

I tried to leverage multiple GPUs using nn.DataParallel. I got an error with torch 1.5, but the same code works with torch 1.4.

To Reproduce

I tested it with the code in this tutorial from PyTorch.org

The following code can be used to reproduce the error:

import torch 
import torch.nn as nn 

from torch.utils.data import Dataset, DataLoader 


# params 
input_size = 5
output_size = 2 

batch_size = 32
data_size = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# dataloader 
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length 
        self.data = torch.randn(length, size)
    
    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len 

rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size), batch_size=batch_size, shuffle=True)


# simple model 
class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        print('\t\tparameters is located at', next(self.parameters()).device)
        return output 
        
model = Model(input_size, output_size)

model = nn.DataParallel(model)
model.to(device)
for batch in iter(rand_loader):
    batch = batch.to(device)
    model(batch)

And I got the following error message:

---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-1-c88fccd98bfb> in <module>
     46 for batch in iter(rand_loader):
     47     batch = batch.to(device)
---> 48     model(batch)
     49

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    153             return self.module(*inputs[0], **kwargs[0])
    154         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 155         outputs = self.parallel_apply(replicas, inputs, kwargs)
    156         return self.gather(outputs, self.output_device)
    157

/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
    163
    164     def parallel_apply(self, replicas, inputs, kwargs):
--> 165         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
    166
    167     def gather(self, outputs, output_device):

/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
     83         output = results[i]
     84         if isinstance(output, ExceptionWrapper):
---> 85             output.reraise()
     86         outputs.append(output)
     87     return outputs

/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
    393             # (https://bugs.python.org/issue2651), so we work around it.
    394             msg = KeyErrorMessage(msg)
--> 395         raise self.exc_type(msg)

StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-1-c88fccd98bfb>", line 39, in forward
    print('\t\tparameters is located at', next(self.parameters()).device)
StopIteration

Expected behavior

With torch 1.4, I got the following output without any error.

                parameters is located at cuda:0
                parameters is located at cuda:1
                parameters is located at cuda:2
                parameters is located at cuda:3

Environment

Collecting environment information...
PyTorch version: 1.5.1
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.2.89
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB

Nvidia driver version: 418.87.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.1
[conda] Could not collect

cc @ezyang @gchanan @zou3519

@zou3519 zou3519 added high priority module: data parallel module: regression It used to work, and now it doesn't triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jun 23, 2020
@zou3519
Contributor

zou3519 commented Jun 23, 2020

This appears to be a regression so I am tentatively labeling it as high-pri.

@mrshenli
Contributor

This is a known regression. After #33907, parameters() on the replicated models is no longer populated, so accessing it in the forward pass (e.g. next(self.parameters())) raises StopIteration. The reason for this change is that parameters in the replicated models are no longer leaves, and hence should no longer be exposed as parameters of the replicas.
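For readers hitting this through next(self.parameters()) in their own forward, here is a minimal sketch of a workaround that avoids internal attributes; the fallback-to-input-device logic is an illustration, not code from this thread. Since a torch 1.5 replica yields an empty parameters() generator, fall back to the device of the scattered input:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        # On torch 1.5, each DataParallel replica exposes an empty parameters()
        # generator, so next(self.parameters()) raises StopIteration.
        # The scattered input already lives on this replica's device, so use it.
        params = list(self.parameters())
        device = params[0].device if params else input.device
        print('\t\tforward running on', device)
        return self.fc(input)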

@mrshenli
Contributor

If you really need to access those parameters, one hacky solution is to read from the _former_parameters field (available in v1.5.1 but not v1.5.0). But this is not recommended and we cannot guarantee that this attribute will always be there in future releases.

def parameters(m, recurse=True):
    def model_parameters(m):
        ps = m._former_parameters.values() \
            if hasattr(m, "_former_parameters") \
            else m.parameters(recurse=False)
        for p in ps:
            yield p
    for m in m.modules() if recurse else [m]:
        for p in model_parameters(m):
            yield p
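For context, a usage sketch of that helper in place of the failing call; this is an illustration under the assumptions above, not an official API, and _former_parameters is an internal attribute that may disappear in future releases:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        # Call the module-level helper above instead of self.parameters();
        # on a replica it reads _former_parameters, which 1.5.1 still populates.
        print('\t\tparameters is located at', next(parameters(self)).device)
        return self.fc(input)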

@mrshenli
Contributor

cc @ngimel

@mnxu7979
Author

Thank you for answering. I don't need to access those parameters directly, but this issue/bug caused a crash when I was using the Hugging Face Transformers package. I downgraded to 1.4 and it worked fine.

@gchanan
Contributor

gchanan commented Jun 29, 2020

Is there an issue for the intersection of this and HuggingFace Transformers?

@ezyang
Contributor

ezyang commented Jun 29, 2020

According to @ngimel, HuggingFace already has an update to deal with this BC breakage.

@ngimel
Collaborator

ngimel commented Jun 29, 2020

@wmmxk

wmmxk commented Sep 21, 2020

I ran into a similar issue. It turns out self.parameters() is called when figuring out which GPU to use.

In my case, I made the following change to the Hugging Face implementation, and it works.


device=input_ids.device,  # after change
# device=next(self.parameters()).device,

@daniel347x

@wmmxk Thanks for the heads-up!

I can confirm this is a bug and your fix (as well as an additional fix, noted below) resolves the issue.

Two changes I made:

1. In transformers/generation_utils.py, change device=next(self.parameters()).device, to device=input_ids.device,
2. In transformers/modeling_gpt2.py, change attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) to attention_mask = attention_mask.to(dtype=input_ids.dtype)
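The common idea behind both changes, shown as a rough sketch (the TinyLM module below is illustrative, not the actual Transformers code): inside forward, derive device and dtype from tensors that were scattered to the replica rather than from self.parameters().

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=10, hidden=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        # Take device/dtype from the inputs, which DataParallel has already
        # scattered to this replica, instead of next(self.parameters()).
        device = input_ids.device
        hidden = self.embed(input_ids)
        if attention_mask is not None:
            attention_mask = attention_mask.to(dtype=hidden.dtype, device=device)
            hidden = hidden * attention_mask.unsqueeze(-1)
        return self.head(hidden)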
