
Wrong gradients when using DistributedDataParallel and autograd.grad. #47562

Closed · TropComplique opened this issue Nov 7, 2020 · 27 comments
Labels: oncall: distributed (Add this issue/PR to distributed oncall triage queue)

@TropComplique commented Nov 7, 2020

🐛 Bug

The gradient reduction across workers doesn't work when:

  1. a gradient penalty is used as the loss, and
  2. a bias is used in the last layer.

To Reproduce

import os
import torch
import torch.nn as nn

import torch.distributed as dist
import torch.multiprocessing as mp


NUM_GPUS = 2


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.w = nn.Parameter(torch.rand(1))
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.w * x.pow(2) + self.b * x/x  # !!! 


def worker(rank):

    torch.manual_seed(rank)
    torch.cuda.set_device(rank)
    device = torch.device(rank)
    dist.init_process_group(backend='nccl', world_size=NUM_GPUS, rank=rank)

    def parallelize(module):
        return nn.parallel.DistributedDataParallel(module, device_ids=[rank])

    model = Model().to(device)
    model = parallelize(model)

    # all workers have the same initial model
    w = model.module.w
    b = model.module.b
    print(f'initial weights at {rank}:', w.data, b.data)

    x = torch.randn(3).to(device)
    x.requires_grad = True
    y = model(x)  # shape [3]

    # all workers have different data
    print(f'input data at {rank}:', x)

    grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
    loss = grad.pow(2).mean(0)  # gradient penalty

    # compare with gradient calculated by hand
    assert torch.isclose(2 * x * w, grad).all()

    model.zero_grad()
    loss.backward()

    # all workers have the same grad
    print(f'final gradient at {rank}:', w.grad, b.grad)

    # compare with gradient calculated by hand
    t = (8 * x.pow(2) * w).mean(0)
    print(f'local gradient at {rank}:', t)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert torch.isclose(t/NUM_GPUS, w.grad).all()


def main():
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    mp.spawn(worker, nprocs=NUM_GPUS, args=())


if __name__ == '__main__':
    main()

The above code works as it should and outputs this:

initial weights at 1: tensor([0.4963], device='cuda:1') tensor([0.], device='cuda:1')
initial weights at 0: tensor([0.4963], device='cuda:0') tensor([0.], device='cuda:0')
input data at 1: tensor([ 0.1994,  2.0394, -0.2148], device='cuda:1', requires_grad=True)
input data at 0: tensor([0.2072, 0.2699, 0.5507], device='cuda:0', requires_grad=True)
final gradient at 1: tensor([3.0861], device='cuda:1') tensor([0.], device='cuda:1')
final gradient at 0: tensor([3.0861], device='cuda:0') tensor([0.], device='cuda:0')
local gradient at 1: tensor(5.6176, device='cuda:1', grad_fn=<MeanBackward1>)
local gradient at 0: tensor(0.5546, device='cuda:0', grad_fn=<MeanBackward1>)

But when I change the model to this:

class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.w = nn.Parameter(torch.rand(1))
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.w * x.pow(2) + self.b  # x/x is removed here!

The output is wrong:

initial weights at 1: tensor([0.4963], device='cuda:1') tensor([0.], device='cuda:1')
initial weights at 0: tensor([0.4963], device='cuda:0') tensor([0.], device='cuda:0')
input data at 1: tensor([ 0.1994,  2.0394, -0.2148], device='cuda:1', requires_grad=True)
input data at 0: tensor([0.2072, 0.2699, 0.5507], device='cuda:0', requires_grad=True)
final gradient at 1: tensor([5.6176], device='cuda:1') None
final gradient at 0: tensor([0.5546], device='cuda:0') None
local gradient at 1: tensor(5.6176, device='cuda:1', grad_fn=<MeanBackward1>)
local gradient at 0: tensor(0.5546, device='cuda:0', grad_fn=<MeanBackward1>)
...
    assert torch.isclose(t/NUM_GPUS, w.grad).all()
AssertionError

Expected behavior

  1. The final gradients at each worker must be the same.
  2. Gradient for b must be zero and not None.

Environment

PyTorch version: 1.7.0+cu110
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti

Nvidia driver version: 450.51.06
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] numpy==1.18.2
[pip] torch==1.7.0+cu110
[pip] torchaudio==0.7.0
[pip] torchvision==0.8.1+cu110

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski

@TropComplique (Author) commented:

Note that you can get the expected behavior if you change the loss to

loss = grad.pow(2).mean(0) + 0.0 * y[0]

@TropComplique (Author) commented:

Also you can get the expected behavior if you use the following trick:

grad = torch.autograd.grad((y + 10000).relu().sum(), x, create_graph=True)[0]

@mrshenli (Contributor) commented:

cc @agolynski

@albanD (Collaborator) commented Nov 18, 2020

@mrshenli is backward with create_graph=True officially supported for DDP?

I would suggest using torchviz (https://github.com/szagoruyko/pytorchviz) to see what graph is created by DDP in the first backward and what's missing from it. That should point to the part of the DDP implementation that is wrong or not differentiable.
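
For reference, a minimal sketch of that inspection, assuming torchviz and graphviz are installed and reusing the variable names from the reproduction script above:

from torchviz import make_dot

y = model(x)
grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
loss = grad.pow(2).mean(0)

# Render the graph that the second backward will traverse; if b is absent from
# it, DDP will keep waiting for a hook that never fires.
make_dot(loss, params=dict(model.named_parameters())).render(f'graph_{rank}')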

@mrshenli (Contributor) commented:

@albanD we never covered that in our tests. How does create_graph affect autograd hooks? Or does it affect the autograd graph when we traverse it from the forward outputs?

> I would suggest using torchviz (https://github.com/szagoruyko/pytorchviz) to see what graph is created by DDP in the first backward and what's missing from it. That should point to the part of the DDP implementation that is wrong or not differentiable.

Thanks for the suggestion, @agolynski is trying that.

@albanD (Collaborator) commented Nov 18, 2020

It only means that the backward actually runs with grad mode enabled and that the computed grads will themselves require gradients.

Note that the bias grad being 0 or None is expected here: in the autograd engine, we don't distinguish between an independent input and a zero gradient, and it can return either None or 0.
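
A minimal single-process sketch of that point (no DDP; values are arbitrary):

import torch

w = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([0.0], requires_grad=True)
x = torch.tensor([2.0], requires_grad=True)

y = (w * x.pow(2) + b).sum()
g, = torch.autograd.grad(y, x, create_graph=True)  # g = 2 * w * x

# y is linear in b, so g does not depend on b at all; asking for d(g)/db needs
# allow_unused=True and yields None rather than a zero tensor.
gw, gb = torch.autograd.grad(g.sum(), [w, b], allow_unused=True)
print(gw, gb)  # tensor([4.]) None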

For the second case, is the issue that you get the right gradients on each worker but they do not get averaged across devices?

@agolynski (Contributor) commented:

Thanks @albanD, @mrshenli

Here are the graphs (two screenshots attached), with and without x/x.

Yes, the issue is that gradients do not get averaged although scatter/gather backward is present in both.

@albanD (Collaborator) commented Nov 18, 2020

Does the sync actually happen via the Gather/Scatter nodes or via the backward hooks?

@agolynski (Contributor) commented:

It seems the issue is that parameter 'b' is declared unused and DDP chokes on it. Once I remove b from the model, everything works and the graph is identical to the 2nd graph (without the x/x expression).

At @mrshenli's suggestion I tried find_unused_parameters=True in DDP, but it didn't change the outcome.

@albanD (Collaborator) commented Nov 18, 2020

From the autograd point of view, the function is linear in b, so it is expected that b does not appear in the double backward graph.

@agolynski (Contributor) commented Nov 18, 2020

@TropComplique could you check if

model.module.b.detach_()
works for your use case as a workaround for now?

I stand corrected (via @albanD):

model.module.b.requires_grad_(False)

is safer.

@TropComplique (Author) commented:

Hi @agolynski.
Where in the code do I need to detach it? I tried adding your line everywhere and it did not work.
It only works if I do this:

model = Model().to(device)
model.b.requires_grad_(False)
model = parallelize(model)
...

But that does not suit my use case, because I need to update parameter b in another place during training.

My use case is the R1 gradient penalty for a discriminator, like here (StyleGAN2 training).
The author of that repository uses the + 0.0 * y[0] trick (like here) to trigger gradient reduction.
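
For context, a hedged sketch of such an R1 penalty step (D, real_images and gamma are hypothetical names, not taken from the linked code; the + 0.0 * scores[0] term is the trick mentioned above):

real_images.requires_grad_(True)
scores = D(real_images)                            # discriminator outputs, shape [N]
grad, = torch.autograd.grad(scores.sum(), real_images, create_graph=True)
r1_penalty = grad.pow(2).reshape(grad.shape[0], -1).sum(1).mean()
loss = 0.5 * gamma * r1_penalty + 0.0 * scores[0]  # keeps all of D's parameters in the loss graph
loss.backward()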

@agolynski (Contributor) commented:

@TropComplique if you need to update b in some other place in your code, would it be possible to reset model.b.requires_grad_(True) before running the corresponding .grad or .backward call?

@agolynski (Contributor) commented Nov 19, 2020

@TropComplique @mrshenli

I think the root cause is that the DDP module detects unused parameters in the prepare_for_backward call

for (const auto& output : outputs) {

which is called from the forward pass (when you do y = model(x)), not when you do .backward().
Since b is used in model, we wait for its autograd hook callback.
On the other hand, b is no longer used in the dependency graph for loss, so this callback never comes and we never trigger Reducer::finalize_backward, which leads to your problem.

With the x/x factor, autograd isn't smart enough to declare b unused, and hence your workaround does the job.
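
A minimal single-process sketch of that mechanism (no DDP; plain tensor hooks stand in for the ones the reducer registers): the hook on b never fires because the gradient-penalty loss does not depend on b.

import torch

w = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([0.0], requires_grad=True)
x = torch.tensor([2.0], requires_grad=True)

w.register_hook(lambda g: print('hook on w fired'))
b.register_hook(lambda g: print('hook on b fired'))

y = w * x.pow(2) + b
grad, = torch.autograd.grad(y.sum(), x, create_graph=True)
loss = grad.pow(2).mean()
loss.backward()  # prints only 'hook on w fired'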

@agolynski (Contributor) commented Nov 19, 2020

@TropComplique
If you define a GradModel whose forward is close to your loss computation, e.g.

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.w = nn.Parameter(torch.rand(1))
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.w * x.pow(2) + self.b # !!!

class GradModel(nn.Module):
    def __init__(self):
        super(GradModel, self).__init__()
        self.m = Model()

    def forward(self, x):
        y = self.m(x)
        grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
        return grad

it works without workarounds. Full code is here:
https://pastebin.com/pxnmv6nZ

Would it work for your use case?
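
A hedged sketch of how the wrapper might be used in the worker from the reproduction script (not taken from the pastebin; find_unused_parameters=True is an assumption here, since b does not appear in the graph of the wrapper's output):

model = GradModel().to(device)
model = nn.parallel.DistributedDataParallel(
    model, device_ids=[rank], find_unused_parameters=True)

x = torch.randn(3, device=device, requires_grad=True)
grad = model(x)             # the autograd.grad call now happens inside DDP's forward
loss = grad.pow(2).mean(0)  # gradient penalty
model.zero_grad()
loss.backward()             # DDP hooks fire for the parameters the loss depends on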

@mrshenli (Contributor) commented:

    grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
    ...
    loss.backward()

Looking at the above code, @albanD, would I be correct to assume this will trigger the autograd hooks twice? Or does autograd.grad not trigger hooks at all?

@albanD (Collaborator) commented Nov 19, 2020

The Tensor hooks are called every time you compute the gradient for it. So if a Tensor is used in both graphs, then the hook will be called twice yes. If a Tensor is used in only one of them, then that hook will be called only once.

@mrshenli (Contributor) commented:

> The Tensor hooks are called every time you compute the gradient for it. So if a Tensor is used in both graphs, then the hook will be called twice yes. If a Tensor is used in only one of them, then that hook will be called only once.

@TropComplique @agolynski in this case, if these two backward passes (.grad() and .backward()) share any parameters, DDP won't work properly, because DDP expects forward and backward passes to alternate; otherwise the hooks installed by DDP won't be handled correctly.

@albanD (Collaborator) commented Nov 19, 2020

Also note that the DDP doc explicitly states: "This module doesn’t work with torch.autograd.grad() (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters)."

@agolynski (Contributor) commented:

@albanD Thanks for pointing this out! torch.autograd.grad() can potentially call autograd hooks in the future (it doesn't now) and break DDP. So this is a 'use at your own risk' situation.

@xmty777 commented Nov 26, 2020

Hi @albanD, I have the same problem when I need second derivatives in DDP.
When I use loss.backward(create_graph=True) with DistributedDataParallel, it causes a memory leak. (It is fine when I don't use DistributedDataParallel.)
Then I found #3818 and changed backward to autograd.grad.
Though the memory leak disappeared, it seems that torch.autograd.grad() doesn't call the autograd hooks.

So, how should I use create_graph?

@albanD (Collaborator) commented Nov 30, 2020

Hi,

I don't think the DDP module supports double backward either. Or at least it is not officially supported and tested for :/

@devzhk commented Dec 16, 2020

Hi, is it possible to add DDP support for double backward and autograd.grad() in the future?

@albanD (Collaborator) commented Dec 16, 2020

It would require a major update to the DDP module, I'm afraid. So while it is possible, I don't think anyone is actively working towards this at the moment, and it might not happen any time soon.

@yelantf commented Jan 22, 2021

> @TropComplique @mrshenli
>
> I think the root cause is that the DDP module detects unused parameters in the prepare_for_backward call
>
> for (const auto& output : outputs) {
>
> which is called from the forward pass (when you do y = model(x)), not when you do .backward().
> Since b is used in model, we wait for its autograd hook callback.
> On the other hand, b is no longer used in the dependency graph for loss, so this callback never comes and we never trigger Reducer::finalize_backward, which leads to your problem.
> With the x/x factor, autograd isn't smart enough to declare b unused, and hence your workaround does the job.

@agolynski Hello, many thanks for the information. I have run into a similar problem. As you mentioned in the reply above, DDP detects unused parameters in the forward pass. However, according to the documentation, this should only happen when we set find_unused_parameters=True, yet this issue occurs even when we set find_unused_parameters=False (as the author of this issue states). So, when find_unused_parameters=False is set, what could be the reason that the gradients of parameters that are used outside of the forward function are not averaged by DDP? Many thanks for your help in advance.

@rohan-varma (Member) commented:

We currently don't have plans to support torch.autograd.grad, but have plans to support double backwards with retain_graph=True: #47260

@gchanan (Contributor) commented Jul 22, 2021

I think there may be some terminology confusion here -- double backwards usually refers to calculating second-order gradients, which requires torch.autograd.grad (or another higher-level API), not calling backward multiple times with retain_graph.
