Wrong gradients when using DistributedDataParallel and autograd.grad. #47562
Comments
Note that you can get the expected behavior if you change the loss to
Also you can get the expected behavior if you use the following trick:
cc @agolynski
@mrshenli is backward with
I would suggest using torchviz (https://github.com/szagoruyko/pytorchviz) to see what graph is created by DDP in the first backward and see what's missing there. That should point to the part of the code that is wrong and not differentiable in the DDP implementation.
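In case it helps, a minimal sketch of what that could look like (torchviz must be installed; model, x and the gradient-penalty loss below are placeholders mirroring this issue, not code taken from it):

import torch
from torchviz import make_dot

y = model(x)  # model is the DDP-wrapped module, x requires grad
grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
loss = (grad ** 2).sum()

# Render the graph of the gradient-penalty loss; comparing this against the
# plain (non-DDP) module should show what is missing.
make_dot(loss, params=dict(model.named_parameters())).render("loss_graph", format="pdf")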
@albanD we never tested that. How does
Thanks for the suggestion, @agolynski is trying that.
It only means that the backward actually runs with grad mode enabled and the computed grad will require gradients. Note that the bias grad being 0 or None is expected here: in the autograd engine, we don't see the difference between an input the output is independent of and a 0 gradient, and it can return either None or 0. For the second case, is the issue that you get the right gradients on each worker but they do not get averaged across devices?
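A tiny standalone illustration of that None-vs-zero point (unrelated to DDP itself):

import torch

w = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

y = (2 * w).sum()  # the output is independent of b

# Mathematically the gradient w.r.t. b is 0, but the engine only sees that b
# is unused, so it returns None (allow_unused=True avoids an error here).
gw, gb = torch.autograd.grad(y, (w, b), allow_unused=True)
print(gw)  # tensor of 2s
print(gb)  # None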
Does the sync actually happen from the Gather/Scatter or from backward hooks?
Seems like the issue is that parameter 'b' is declared unused and DDP choked on it. Once I remove b from the model, everything works and the graph is identical to the 2nd graph (without the x/x expression). At the suggestion of @mrshenli I've tried find_unused_parameters=True in DDP, but it didn't change the outcome.
From the autograd point of view, the function is linear in
@TropComplique could you check if
I stand corrected (via @albanD):
is safer.
Hi @agolynski.
But it does not suit my use case. My use case is the R1 gradient penalty for a discriminator, like here (StyleGAN2 training).
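For reference, a rough sketch of that R1 penalty pattern (D, real_images, d_loss and gamma are placeholders following the StyleGAN2 recipe, not code from this issue):

real_images.requires_grad_(True)
real_scores = D(real_images)  # D would be the DDP-wrapped discriminator

# Gradient of the scores w.r.t. the inputs, kept in the graph so that the
# penalty itself stays differentiable w.r.t. D's parameters.
grad_real, = torch.autograd.grad(real_scores.sum(), real_images, create_graph=True)
r1_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

loss = d_loss + (gamma / 2) * r1_penalty
loss.backward()  # under DDP, this is where the issue shows up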
@TropComplique if you need to update |
I think the root cause is that the DDP module detects unused parameters in the prepare_for_backward call (pytorch/torch/lib/c10d/reducer.cpp, Line 1006 in db767b7), i.e. when you do the forward pass (y = model(x)), not when you do .backward(). Since b is used in model, we wait for its autograd hook callback. On the other hand, b is no longer used in the dependency graph for loss, so this callback never comes and we never trigger Reducer::finalize_backward, which leads to your problem.
With
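To make that mechanism concrete, a hedged sketch of the failing pattern as I understand it (net, rank and the loss are illustrative, not the exact repro code):

import torch

model = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank])

x = torch.randn(8, 3, device=f"cuda:{rank}", requires_grad=True)
y = model(x)  # prepare_for_backward runs here; b is "used" in the forward, so DDP waits for its hook

grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
loss = (grad ** 2).sum()  # b no longer appears in the graph of loss
loss.backward()           # b's hook never fires, so Reducer::finalize_backward never runs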
@TropComplique
it works without workarounds. Full code is here:
Would it work for your use case?
grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
...
loss.backward()
Looking at the above code, @albanD, would I be correct if I assume this will trigger the autograd hooks twice? Or
The Tensor hooks are called every time you compute the gradient for that Tensor. So if a Tensor is used in both graphs, then the hook will be called twice, yes. If a Tensor is used in only one of them, then that hook will be called only once.
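A small standalone example of that behaviour:

import torch

calls = []
x = torch.randn(4, requires_grad=True)
x.register_hook(lambda g: calls.append(g))  # fires whenever a gradient for x is computed

y = (x ** 2).sum()
g = torch.autograd.grad(y, x, create_graph=True)[0]  # first graph: hook fires once

loss = (g ** 2).sum()  # depends on x through g
loss.backward()        # second graph: hook fires again
print(len(calls))      # 2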
@TropComplique @agolynski in this case, if these two backward passes ( |
Also note that the DDP doc explicitly states: "This module doesn’t work with torch.autograd.grad() (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters)."
@albanD Thanks for pointing this out! torch.autograd.grad() can potentially call the autograd hooks in the future (it doesn't now) and break DDP. So this is a 'use at your own risk' situation.
Hi @albanD, I have the same problem when I require second derivatives in DDP. So, how should I use
Hi, I don't think the DDP module supports double backward either. Or at least it is not officially supported and tested for :/
Hi, is it possible to add DDP support for double backward and autograd.grad() in the future?
It would require a major update to the DDP module, I am afraid. So, while it is possible, I don't think anyone is actively working on this at the moment, so it might not happen any time soon.
@agolynski Hello, many thanks for your information. I have run into a similar problem. As you mentioned in the reply above, DDP detects unused parameters in the forward pass. However, according to the documentation, it seems that this only occurs when we set
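For what it's worth, the flag in question is passed at DDP construction time, roughly like this (net and rank are placeholders):

model = torch.nn.parallel.DistributedDataParallel(
    net,
    device_ids=[rank],
    find_unused_parameters=True,  # traverse the autograd graph after forward to mark unused params
)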
We currently don't have plans to support torch.autograd.grad, but have plans to support double backwards with retain_graph=True: #47260
I think there may be some terminology confusion here -- double backwards usually refers to calculating second-order gradients, which requires
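A minimal sketch of that distinction, in case it's useful (model and x are placeholders, with x requiring grad):

# retain_graph=True: keep the graph alive so backward can be called again,
# still producing only first-order gradients.
y = model(x)
y.sum().backward(retain_graph=True)
y.sum().backward()  # would raise without retain_graph=True above

# create_graph=True: build a graph of the backward pass itself, so the
# resulting gradients can be differentiated again (second order / double backward).
grad_x, = torch.autograd.grad(model(x).sum(), x, create_graph=True)
grad_x.pow(2).sum().backward()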
🐛 Bug
The gradient reduction across workers doesn't work when:
To Reproduce
The above code works as it should and outputs this:
But when I change the model to this:
The output is wrong:
Expected behavior
The gradient for b must be zero and not None.
Environment
PyTorch version: 1.7.0+cu110
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2
Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
Nvidia driver version: 450.51.06
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip] numpy==1.18.2
[pip] torch==1.7.0+cu110
[pip] torchaudio==0.7.0
[pip] torchvision==0.8.1+cu110
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski