nn.DataParallel with multiple outputs, weird gradient result None #15716
Comments
I think this issue has the same root cause as "v1.0.0 nn.utils.weight_norm seems to nullify gradients of unrelated parameters if wrapped in DataParallel".
Agree with:
I forgot to mention that if I change the return of the forward() call of NN4 from
I'm also running into this issue when I compute the loss based on the returned tensors. Are there any known workarounds for this?

EDIT: Seems like someone came up with a workaround here: r9y9/wavenet_vocoder@6b9c932

Here's my workaround for the module version, based on the workaround from above:

```python
from itertools import chain

import torch.nn as nn


class DataParallelFix(nn.DataParallel):
    """
    Temporary workaround for https://github.com/pytorch/pytorch/issues/15716.

    Keeps references to the replicas and their outputs so they are not freed
    before backward() runs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._replicas = None
        self._outputs = None

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    "module must have its parameters and buffers "
                    "on device {} (device_ids[0]) but found one of "
                    "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])

        self._replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        self._outputs = self.parallel_apply(self._replicas, inputs, kwargs)
        return self.gather(self._outputs, self.output_device)
```

This just keeps the replicas and outputs in memory.
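A short usage sketch for reference: DataParallelFix is meant as a drop-in replacement for nn.DataParallel. The model and data here are placeholders, not from the original post.

```python
# Hypothetical usage sketch: swap nn.DataParallel for DataParallelFix.
# `net` stands in for any module; in practice it would return several tensors.
import torch
import torch.nn as nn

net = nn.Linear(16, 4)  # placeholder model
model = DataParallelFix(net.cuda(), device_ids=[0, 1])
out = model(torch.randn(8, 16).cuda())
out.sum().backward()

# The original module's parameters (model.module) should now have gradients.
print(all(p.grad is not None for p in model.module.parameters()))
```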
Was this fixed in the latest release, without having to use the workaround posted?
I tried eric-zhenhai's test code with PyTorch 1.4.0 on Ubuntu with 3 GPUs, and this is the output I get:
Does that mean we can now safely avoid the workaround provided by r9y9/wavenet_vocoder?
I'm seeing the same problem (version 1.8.1). It seems that parameters and buffers are not "gathered" in the DataParallel forward the way the outputs are, and I run into trouble when I try to access them.
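For reference, the non-replicated parameters and buffers live on the wrapped module itself and can be reached through the wrapper's .module attribute. A small sketch with an example layer layout (not from the original comment, and not a confirmed fix for the gradient issue above):

```python
import torch.nn as nn

# Toy wrapped module; the Linear/BatchNorm layout is just an example.
net = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
wrapper = nn.DataParallel(net)

weight = wrapper.module[0].weight              # parameter of the original module
running_mean = wrapper.module[1].running_mean  # buffer of the original module
print(weight.shape, running_mean.shape)
```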
🐛 Bug
Under PyTorch 1.0.0, wrapping a model that returns multiple outputs in nn.DataParallel() does not calculate gradients properly.
To Reproduce
On servers with >=2 GPUs, under PyTorch 1.0.0
Steps to reproduce the behavior:
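The original dp_test.py script is not shown in this report; the following is only a rough sketch of the kind of reproduction described, with hypothetical module and layer names (NN4 is taken from the comments above, the rest is assumed):

```python
# Rough sketch of a reproduction script (hypothetical; not the original dp_test.py).
# A module with multiple outputs is wrapped in nn.DataParallel and we check
# whether parameter gradients are populated after backward().
import torch
import torch.nn as nn


class NN4(nn.Module):
    """Toy stand-in for the module from the report: forward() returns two tensors."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc1(x), self.fc2(x)


model = nn.DataParallel(NN4().cuda())
out1, out2 = model(torch.randn(4, 8).cuda())
loss = out1.sum() + out2.sum()
loss.backward()

# Per the report, some of these come back as None under 1.0.0,
# while under 0.4.0 they are all populated.
for name, p in model.module.named_parameters():
    print(name, "grad is None" if p.grad is None else "grad ok")
```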
Then run the code with "CUDA_VISIBLE_DEVICES=0,1 python dp_test.py" in the console. Under PyTorch 1.0.0, I get:
However, under PyTorch 0.4.0, I get:
Expected behavior
aaa[0][1].grad
should not be None under PyTorch 1.0.0.
Environment
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB
Nvidia driver version: 396.44
cuDNN version: Probably one of the following:
/usr/local/cuda-8.0/lib64/libcudnn.so.6.0.21
/usr/local/cuda-8.0/lib64/libcudnn_static.a
/usr/local/cuda-9.0/lib64/libcudnn.so.7.3.1
/usr/local/cuda-9.0/lib64/libcudnn_static.a
/usr/local/cuda-9.1/lib64/libcudnn.so.7.0.5
/usr/local/cuda-9.1/lib64/libcudnn_static.a
/usr/local/cuda-9.2/lib64/libcudnn.so.7.3.1
/usr/local/cuda-9.2/lib64/libcudnn_static.a
Additional context