
nn.DataParallel with multiple outputs, weird gradient result None #15716

Open
eric-zhenhai opened this issue Jan 3, 2019 · 6 comments
Assignees
Labels
oncall: distributed  Add this issue/PR to distributed oncall triage queue
triaged  This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@eric-zhenhai

🐛 Bug

Under PyTorch 1.0, wrapping a model that returns multiple outputs in nn.DataParallel() does not compute gradients properly.

To Reproduce

On a server with >=2 GPUs, under PyTorch 1.0.0.
Steps to reproduce the behavior:

  1. Use the code below (saved as dp_test.py):
import torch.nn as nn
import torch
import torch.nn.functional as F

DEVICE = torch.device('cuda:0')


class NN4(nn.Module):
    def __init__(self):
        super(NN4, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc21 = nn.Linear(4, 1)

    def forward(self, x):
        x = F.selu(self.fc1(x))
        x1 = torch.sigmoid(self.fc21(x))
        # return x, x  # not None
        return x, x1  # None


def test_NN4():
    images = torch.randn(4, 8).to(DEVICE)
    fimages = torch.randn(4, 8).to(DEVICE)

    D = NN4().to(DEVICE)
    D = nn.DataParallel(D)
    D.zero_grad()

    d_loss = D(images)[0].mean() - D(fimages)[0].mean()
    print('d_loss: -->', d_loss)
    d_loss.backward()

    print('-------->>>')
    aaa = list(D.named_parameters())
    print(aaa[0][0])
    print(aaa[0][1].grad)

    D2 = NN4().to(DEVICE)
    D2.zero_grad()

    d2_loss = D2(images)[0].mean() - D2(fimages)[0].mean()
    print('d2_loss: -->', d2_loss)
    d2_loss.backward()

    print('-------->>>')
    aaa2 = list(D2.named_parameters())
    print(aaa2[0][0])
    print(aaa2[0][1].grad)


if __name__ == '__main__':
    # Run the test when invoked as a script (needed to reproduce the output below).
    test_NN4()

  2. Run the script with "CUDA_VISIBLE_DEVICES=0,1 python dp_test.py" in the console. Under PyTorch 1.0.0, I get:

d_loss: --> tensor(0.1488, device='cuda:0', grad_fn=<SubBackward0>)
-------->>>
module.fc1.weight
None
d2_loss: --> tensor(0.0149, device='cuda:0', grad_fn=<SubBackward0>)
-------->>>
fc1.weight
tensor([[ 0.0284, -0.1972,  0.1553, -0.3356,  0.2737, -0.2083,  0.1420, -0.3533],
        [ 0.0473, -0.1277,  0.0903, -0.3214,  0.2385, -0.1815,  0.0369, -0.1991],
        [ 0.0231, -0.0949,  0.1218, -0.3591,  0.1832, -0.2311,  0.0685, -0.1934],
        [ 0.0858, -0.1129,  0.1216, -0.3774,  0.3795, -0.1308, -0.0006, -0.1790]],
       device='cuda:0')

However, under PyTorch 0.4.0, I get:

d_loss: --> tensor(0.1650, device='cuda:0')
-------->>>
module.fc1.weight
tensor([[-0.2463,  0.0740, -0.2929, -0.2576, -0.0346,  0.1679,  0.1501,
         -0.2375],
        [-0.2666,  0.1135, -0.3788, -0.2865, -0.0519, -0.0217,  0.0564,
         -0.2942],
        [-0.2802,  0.1207, -0.3556, -0.2959, -0.0245, -0.0106,  0.0902,
         -0.2851],
        [-0.3193,  0.0788, -0.4258, -0.2705, -0.1212,  0.0063,  0.0322,
         -0.2649]], device='cuda:0')
d2_loss: --> tensor(1.00000e-02 *
       8.7814, device='cuda:0')
-------->>>
fc1.weight
tensor([[-0.3051,  0.1011, -0.3452, -0.2829, -0.0318, -0.0299,  0.0642,
         -0.2442],
        [-0.2536,  0.1279, -0.3869, -0.3891, -0.0362,  0.0412,  0.1000,
         -0.3384],
        [-0.3321,  0.0059, -0.4514, -0.2517, -0.1013,  0.0374,  0.0124,
         -0.1985],
        [-0.3147,  0.0331, -0.3343, -0.2498, -0.0903, -0.0668,  0.0555,
         -0.2360]], device='cuda:0')

Expected behavior

aaa[0][1].grad should not be None under PyTorch 1.0.0.
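
For reference, a small check like the following (a sketch of mine, not part of dp_test.py; check_fc1_grad is a made-up name) states that expectation programmatically:

def check_fc1_grad(model):
    # fc1 feeds the output that the loss uses, so its weight must have a
    # gradient after backward(); works for both the wrapped and plain model.
    params = dict(model.named_parameters())
    name = 'module.fc1.weight' if 'module.fc1.weight' in params else 'fc1.weight'
    assert params[name].grad is not None, '{}.grad is None'.format(name)

Calling check_fc1_grad(D) right after d_loss.backward() fails under 1.0.0 and passes under 0.4.0 with the outputs shown above.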

Environment

PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB

Nvidia driver version: 396.44
cuDNN version: Probably one of the following:
/usr/local/cuda-8.0/lib64/libcudnn.so.6.0.21
/usr/local/cuda-8.0/lib64/libcudnn_static.a
/usr/local/cuda-9.0/lib64/libcudnn.so.7.3.1
/usr/local/cuda-9.0/lib64/libcudnn_static.a
/usr/local/cuda-9.1/lib64/libcudnn.so.7.0.5
/usr/local/cuda-9.1/lib64/libcudnn_static.a
/usr/local/cuda-9.2/lib64/libcudnn.so.7.3.1
/usr/local/cuda-9.2/lib64/libcudnn_static.a

Additional context

@ZhichengHuang

@eric-zhenhai
Author

eric-zhenhai commented Jan 4, 2019

Agree with:

I think this issue is the same reason v1.0.0 nn.utils.weight_norm seems to nullify gradients of unrelated parameters if wrapped in DataParallel

I forgot to mention that if I change the return of NN4's forward() from
return x, x1 to return x, x, then the gradient is calculated under PyTorch 1.0.0, which makes things even more mysterious.
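
For anyone who wants to reproduce just that difference, here is a switchable variant of NN4 (a sketch of mine; the multi_out flag is made up for illustration):

class NN4Switchable(nn.Module):
    """Same layers as NN4; `multi_out` toggles which pair of tensors is returned."""

    def __init__(self, multi_out=True):
        super(NN4Switchable, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc21 = nn.Linear(4, 1)
        self.multi_out = multi_out

    def forward(self, x):
        x = F.selu(self.fc1(x))
        if self.multi_out:
            return x, torch.sigmoid(self.fc21(x))  # fc1 grads come back None under 1.0.0 + DataParallel
        return x, x  # fc1 grads are populated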

@ailzhang ailzhang added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jan 7, 2019
@teng-li teng-li removed the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jan 9, 2019
@fmassa fmassa assigned soumith and unassigned teng-li Jan 14, 2019
@cpuhrsch cpuhrsch removed the bug label Apr 9, 2019
@gchanan gchanan added oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Apr 10, 2019
@keunhong

keunhong commented May 21, 2019

I'm also running into this issue: if I compute the loss from tensors returned in a dict from forward(), I get None gradients when using DataParallel.

Are there any known workarounds for this?

EDIT: Seems like someone came up with a workaround here: r9y9/wavenet_vocoder@6b9c932

Here's a module-level version of that workaround:

from itertools import chain

import torch.nn as nn


class DataParallelFix(nn.DataParallel):
    """
    Temporary workaround for https://github.com/pytorch/pytorch/issues/15716.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._replicas = None
        self._outputs = None

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    "module must have its parameters and buffers "
                    "on device {} (device_ids[0]) but found one of "
                    "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])

        # Keep references to the replicas and their outputs on self
        # (the workaround from r9y9/wavenet_vocoder@6b9c932).
        self._replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        self._outputs = self.parallel_apply(self._replicas, inputs, kwargs)

        return self.gather(self._outputs, self.output_device)

This just keeps replicas and outputs in memory.
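
If it helps anyone, dropping it into the original test is a one-line change (a sketch, assuming DataParallelFix from above is defined in the same file as NN4):

D = NN4().to(DEVICE)
D = DataParallelFix(D)  # instead of nn.DataParallel(D)

d_loss = D(images)[0].mean() - D(fimages)[0].mean()
d_loss.backward()
print(dict(D.named_parameters())['module.fc1.weight'].grad)  # should no longer be None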

@TheButlah

TheButlah commented Jul 9, 2019

Was this fixed in the latest release, without having to use the workaround posted above?

@shaun95

shaun95 commented Mar 19, 2020

I tried eric-zhenhai's test code on PyTorch 1.4.0, on Ubuntu with 3 GPUs, and this is the output I get:

d_loss: --> tensor(0.0124, device='cuda:0', grad_fn=<SubBackward0>)
-------->>>
module.fc1.weight
tensor([[-0.2214,  0.1978,  0.0139, -0.4091,  0.0889, -0.1214,  0.1038, -0.0677],
        [-0.1105,  0.0551, -0.0888, -0.5148,  0.1117, -0.2265,  0.1703,  0.0035],
        [-0.1095,  0.1248, -0.0389, -0.3650,  0.1202, -0.1276,  0.1087, -0.0117],
        [-0.2190,  0.0884, -0.0526, -0.4507,  0.0877, -0.1610,  0.1633, -0.0042]],
       device='cuda:0')
d2_loss: --> tensor(0.0059, device='cuda:0', grad_fn=<SubBackward0>)
-------->>>
fc1.weight
tensor([[-0.0805,  0.1024, -0.0697, -0.4574,  0.0926, -0.1737,  0.1632,  0.0266],
        [-0.0958,  0.0485, -0.0765, -0.4616,  0.0922, -0.1867,  0.1695,  0.0476],
        [-0.0902,  0.0968, -0.0758, -0.4487,  0.0894, -0.1734,  0.1564,  0.0272],
        [-0.0398,  0.0116, -0.1327, -0.4039,  0.0551, -0.2004,  0.1249,  0.1074]],
       device='cuda:0')

Does that mean we can now safely avoid the workaround provided by r9y9/wavenet_vocoder?

@lopsided

I'm seeing the same problem as this (version 1.8.1). It seems that parameters and buffers are not "gathered" in the DataParallel forward the way the outputs are: when I try to access them like net.mymodule.mybuffer, I only see the buffer values from the cuda:0 device. In my case, with 2 GPUs, that is only half the values.
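
One workaround I have used for the buffer half of this (a sketch of mine, not an official fix; StatsModule and the running-max example are made up): return the per-replica values from forward() so DataParallel gathers them from every device, instead of writing them into a buffer that only survives on cuda:0.

import torch
import torch.nn as nn


class StatsModule(nn.Module):
    # Instead of updating a registered buffer inside forward(), return the
    # per-replica statistic so DataParallel's gather collects one value per GPU.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        out = self.fc(x)
        batch_max = out.detach().max().view(1)  # shape (1,) so gather can concatenate
        return out, batch_max


# With 2 GPUs, `maxes` below has shape (2,): one entry per replica, which the
# caller can then reduce (e.g. maxes.max()) and store wherever needed.
# out, maxes = nn.DataParallel(StatsModule()).cuda()(torch.randn(8, 8).cuda())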
