Issue for DataParallel #8637

Closed
BestSonny opened this issue Jun 19, 2018 · 11 comments

@BestSonny
Contributor

I ran into a problem when trying to use multi-GPU training.

  1. If I wrap the model's forward function as follows:
def __init__(self, ...):
        self.operation_function = self._gaussian

def forward(self, x, z):
        output = self.operation_function(x, z)
        return output

def _gaussian(self, x, z):
        # real forward code
        ...

Then

gpu_num = torch.cuda.device_count()
print('GPU NUM: {:2d}'.format(gpu_num))
if gpu_num > 1:
    model = torch.nn.DataParallel(model, list(range(gpu_num))).cuda()

I get the following error during multi-GPU training:

File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 468, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 123, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 133, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/parallel_apply.py", line 77, in parallel_apply
    raise output
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
  2. If I directly put the code of the _gaussian function inside the forward function, multi-GPU training works:
def forward(self, x, z):
        # real forward code
        ...

@soumith Do you have any comments about this?

@BestSonny
Contributor Author

BestSonny commented Jun 19, 2018

Demo code to reproduce the error is provided below: testModule2 works fine while testModule does not.

import torch
from torch import nn

class testModule(nn.Module):
    def __init__(self):
        super(testModule, self).__init__()
        self.g = nn.Conv2d(in_channels=1, out_channels=1,
                         kernel_size=1, stride=1, padding=0)
        self.operation_function = self._realOperation

    def forward(self, x):
        output = self.operation_function(x)
        return output

    def _realOperation(self, x):
        x = self.g(x)
        return x

class testModule2(nn.Module):
    def __init__(self):
        super(testModule2, self).__init__()
        self.g = nn.Conv2d(in_channels=1, out_channels=1,
                         kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        x = self.g(x)
        return x

if __name__ == '__main__':
        input = torch.rand(4, 1, 1, 1).cuda()
        net = testModule()
        net2 = testModule2()
        gpu_num = torch.cuda.device_count()
        print('GPU NUM: {:2d}'.format(gpu_num))
        if gpu_num > 1:
            net = torch.nn.DataParallel(net, list(range(gpu_num))).cuda()
            net2 = torch.nn.DataParallel(net2, list(range(gpu_num))).cuda()
        out2 = net2(input)
        print(out2.size())
        out = net(input)
        print(out.size())

@ssnl
Collaborator

ssnl commented Jun 19, 2018

Yeah, testModule won't work because you saved the method self._realOperation, which is bound to this particular testModule instance, as the attribute self.operation_function. When the module is broadcast to different GPUs, this attribute is not a tensor, so it is simply duplicated. As a result, every broadcast copy of the module ends up with the attribute referring to the same bound method, which is bound to the same original instance and therefore uses the same self.g, whose parameters live only on GPU 0. Hence the error on GPU 1.

In testModule2, in the forward of each broadcast copy, the dynamically looked-up self.g is the g attribute of that copy, whose parameters have been broadcast to the corresponding GPU.
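
For illustration, here is a minimal, GPU-free sketch of the Python semantics involved (class A and copy.copy are stand-ins for the module and its replication): a bound method saved as an attribute keeps referring to the original instance even after the attribute dictionary is copied to another object.

import copy

class A(object):
    def __init__(self):
        self.tag = 'original'
        self.op = self._op          # bound method saved as an attribute

    def _op(self):
        return self.tag             # 'self' is whatever instance the method is bound to

a = A()
b = copy.copy(a)                    # shallow copy of a, roughly analogous to a replica
b.tag = 'replica'
print(b.op())                       # 'original' -- still bound to a
print(b._op())                      # 'replica'  -- looked up dynamically on b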

If you simply do not save the method as an attribute, the code should work fine: e.g., call self._realOperation directly in forward, or define operation_function as another method of the class, etc. (see the sketch below).
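
For illustration, a fixed version of the testModule demo above might look like this (a sketch; testModuleFixed is a hypothetical name):

import torch
from torch import nn

class testModuleFixed(nn.Module):
    def __init__(self):
        super(testModuleFixed, self).__init__()
        self.g = nn.Conv2d(in_channels=1, out_channels=1,
                           kernel_size=1, stride=1, padding=0)
        # note: no bound method is stored as an attribute

    def forward(self, x):
        # look the method up on self at call time, so each replica uses its own self.g
        return self._realOperation(x)

    def _realOperation(self, x):
        return self.g(x)

# net = torch.nn.DataParallel(testModuleFixed(), list(range(torch.cuda.device_count()))).cuda()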

In theory we could fix this by checking the broadcast module's __dict__ to see whether any of its attributes are methods bound to that instance. But I don't think this pattern is very common.

@maozhiqiang

I use DataParallel to compute the loss from https://github.com/JohnVinyard/experiments/blob/master/audio-loss/audio_loss.py
It works on a single GPU, but with multiple GPUs there is a problem.
My code looks like this:
samplerate = SR22050()
scale = BarkScale(FrequencyBand(50, samplerate.nyquist - 300), n_bands=512)
perceptual_loss = PerceptualLoss(
    scale,
    samplerate,
    lap=1,
    log_factor=10,
    basis_size=512,
    frequency_weighting=AWeighting(),
    cosine_similarity=True).to(device)

# perceptual_loss(input=, target=)
from audio import load_wav
# print(1 + perceptual_loss(input=torch.Tensor([a[:41885]]).cuda(), target=torch.Tensor([b[:41885]]).cuda()))
loss = 1 + torch.nn.parallel.data_parallel(perceptual_loss, (torch.Tensor([a[:134301]]).cuda(), torch.Tensor([b[:134301]]).cuda()))

The error looks like this:
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
How can I solve this?

@willprice
Contributor

@ssnl, is it possible to revisit the decision not to support this in PyTorch? pretrained-models.pytorch uses this pattern to build on top of the torchvision models. Alternatively, are there any solutions that we can apply to pretrained-models.pytorch to resolve this issue?

@hegc

hegc commented Aug 6, 2019

@willprice You can see this: Cadene/pretrained-models.pytorch#145

@ssnl
Collaborator

ssnl commented Aug 6, 2019

reopen for discussion!

@ssnl
Collaborator

ssnl commented Aug 6, 2019

cc @apaszke @colesbury @soumith @fmassa for opinions

@willprice
Contributor

willprice commented Aug 6, 2019

> @willprice You can see this: Cadene/pretrained-models.pytorch#145

The issue with this pull request is that, while it is fine for new users of pretrained-models, it changes the classes of the returned models, which breaks downstream code relying on instance checks.
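
For illustration, one possible workaround that keeps the original class (so isinstance checks still pass) is to store the name of the method rather than the bound method itself and resolve it with getattr inside forward. The sketch below uses a hypothetical MyModel, not the actual pretrained-models code:

import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self, mode='gaussian'):
        super(MyModel, self).__init__()
        self.g = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
        # store a plain string instead of a bound method; strings duplicate safely
        self.operation_name = '_' + mode

    def forward(self, x):
        # getattr(self, ...) re-binds the method to *this* replica at call time,
        # so each replica uses its own self.g
        return getattr(self, self.operation_name)(x)

    def _gaussian(self, x):
        return self.g(x)

# model = torch.nn.DataParallel(MyModel(), list(range(torch.cuda.device_count()))).cuda()
# out = model(torch.rand(4, 1, 8, 8).cuda())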

@apaszke
Contributor

apaszke commented Aug 8, 2019

I still don't think we should be handling this. Our cloning logic is already quite expensive, and we already take a fair number of shortcuts to make it faster. Doing something like this would require us to inspect every single attribute of each module, and there are just too many of those. Remember that this is not a one-time cost; it happens at every iteration!

@wmmxk

wmmxk commented Aug 24, 2019

So I should put all operations of a model in one method, namely the forward method, to make it work on multiple GPUs using DataParallel. Is that right?

@colesbury
Member

@wmmxk you should see SsnL's post above. Not all operations need to be in the same method, but dynamically binding methods to attributes won't work.
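
For illustration, a minimal sketch (hypothetical SplitModel) showing that spreading work across several methods is fine with DataParallel, as long as each method is looked up on self at call time rather than stored as a bound-method attribute:

import torch
from torch import nn

class SplitModel(nn.Module):
    def __init__(self):
        super(SplitModel, self).__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)

    def forward(self, x):
        x = self._preprocess(x)      # direct method calls: re-bound per replica
        x = self.conv(x)
        return self._postprocess(x)

    def _preprocess(self, x):
        return x - x.mean()

    def _postprocess(self, x):
        return torch.relu(x)

# model = torch.nn.DataParallel(SplitModel()).cuda()   # works across multiple GPUs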
