Cannot update part of the parameters in DistributedDataParallel. #22049

Open · fuzihaofzh opened this issue Jun 20, 2019 · 15 comments
Labels
oncall: distributed (add this issue/PR to the distributed oncall triage queue), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@fuzihaofzh

fuzihaofzh commented Jun 20, 2019

🐛 Bug

When I use multiple GPUs and the loss is computed from only part of the parameters, I get the following error. Using only one GPU works fine.

To Reproduce

Steps to reproduce the behavior:

Define a network in which the loss depends on only part of the parameters, wrap it with DistributedDataParallel, and train on multiple GPUs. We get:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable). (prepare_for_backward at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:429)
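
For concreteness, here is a minimal sketch of the kind of model that triggers this (not the original code from this report; the module, `LOCAL_RANK` handling, and training loop are illustrative, assuming a launch with `torchrun`): only `self.used` contributes to the loss, so gradients for `self.unused` never arrive and the reducer cannot finish.

```python
# Hypothetical minimal reproduction; assumes a multi-process launch (e.g. torchrun).
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class PartialNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(10, 1)
        self.unused = nn.Linear(10, 1)  # parameters never contribute to the loss

    def forward(self, x):
        return self.used(x)             # self.unused is never exercised

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

ddp_model = DDP(PartialNet().cuda(), device_ids=[local_rank])
opt = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

for step in range(2):                   # the RuntimeError is raised on the second iteration
    x = torch.randn(8, 10, device="cuda")
    opt.zero_grad()
    loss = ddp_model(x).sum()
    loss.backward()                     # gradients for self.unused never arrive
    opt.step()
```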

Expected behavior

The parameters that do not contribute to the loss should simply not be updated; training should behave the same as on a single GPU.

Environment

PyTorch version: 1.2.0.dev20190620
CUDA used to build PyTorch: 9.0.176
OS: CentOS Linux release 7.5.1804 (Core)
GCC version: (crosstool-NG 1.23.0.449-a04d0) 7.3.0
CMake version: version 2.8.12.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX 1080
GPU 1: GeForce GTX 1080
GPU 2: GeForce GTX 1080
GPU 3: GeForce GTX 1080

Nvidia driver version: 396.26
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.3.2
[pip3] numpy==1.15.4
[pip3] pytorch-pretrained-bert==0.4.0
[pip3] torch==1.0.1.post2
[pip3] torchfile==0.1.0
[pip3] torchtext==0.4.0
[pip3] torchvision-nightly==0.2.1
[conda] pytorch-pretrained-bert 0.6.2 pypi_0 pypi
[conda] torch-nightly 1.2.0.dev20190620 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchtext 0.4.0 pypi_0 pypi

@pytorchbot added the oncall: distributed label on Jun 20, 2019
@pietern
Contributor

pietern commented Jun 21, 2019

Did you try the instructions in the error message?

@pietern added the triaged label on Jun 21, 2019
@fuzihaofzh
Author

fuzihaofzh commented Jun 21, 2019

@pietern The instructions there only give a way to find which variables are not included. However, since I want to train the first part of the net first, I intentionally do not update the parameters of the second net.

@pietern
Contributor

pietern commented Jun 25, 2019

How do you freeze the parameters you don't want to train?

If you set param.requires_grad = False before wrapping the model with torch.nn.parallel.DistributedDataParallel, it should work. If you set this after wrapping, then DDP will still expect gradients for those parameters and give the error you posted.
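
A minimal sketch of that ordering (the two-layer nn.Sequential is a placeholder, and `local_rank` is assumed to come from the launcher as in the reproduction sketch above): freeze first, wrap second, so DDP only registers gradient hooks for the parameters that still require gradients.

```python
# Sketch: freeze a subset of parameters *before* constructing DDP.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

model = nn.Sequential(
    nn.Linear(10, 10),  # part to train now
    nn.Linear(10, 10),  # part to keep frozen
).cuda()

for p in model[1].parameters():
    p.requires_grad = False     # freeze before wrapping

# DDP looks at requires_grad at construction time, so it will not wait
# for gradients of the frozen parameters during backward.
ddp_model = DDP(model, device_ids=[local_rank])
```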

@fuzihaofzh
Author

@pietern Hi, thanks for your answer. I want to train two networks alternately, so requires_grad is set dynamically after wrapping with DDP. I think DDP should have some way to dynamically freeze part of the network; I feel this is a commonly needed feature.

@pietern
Contributor

pietern commented Jun 26, 2019

Just to make sure I understand correctly:

  • You have 2 models that are each individually wrapped with DDP
  • You want to train them alternately
  • You do so by setting all model parameters to requires_grad = False

For the last one, do you freeze ALL parameters or only a subset? I believe that you can freeze the whole model today, and it should work out of the box. Only if you freeze a subset of the model will you get the error message you posted.

@fuzihaofzh
Author

I have 2 models, but I wrapped them in one DDP. This may be the problem. Is it possible to wrap all models in one DDP and dynamically freeze the parameters? I think this would make things much easier.

@pietern
Contributor

pietern commented Jun 27, 2019

It is not possible today to partially freeze a DDP-wrapped model. Either you freeze the whole thing (so no model parameter receives gradients) or nothing at all. If you want to alternate between two models, it is best to wrap them separately and freeze each one entirely when it is not being trained.
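
A rough sketch of that pattern (hypothetical models and loop; the process group and `local_rank` are assumed to be set up as usual): each model gets its own DDP wrapper and optimizer, and only the model that ran a forward pass in a given step participates in gradient reduction.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Two placeholder models, each wrapped separately.
net_a = DDP(nn.Linear(10, 1).cuda(), device_ids=[local_rank])
net_b = DDP(nn.Linear(10, 1).cuda(), device_ids=[local_rank])
opt_a = torch.optim.SGD(net_a.parameters(), lr=0.1)
opt_b = torch.optim.SGD(net_b.parameters(), lr=0.1)

for step in range(100):
    x = torch.randn(8, 10, device="cuda")
    # Alternate: even steps train net_a, odd steps train net_b.
    net, opt = (net_a, opt_a) if step % 2 == 0 else (net_b, opt_b)
    opt.zero_grad()
    loss = net(x).sum()
    loss.backward()   # reduction only runs for the model used this step
    opt.step()
```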

@xf3227

xf3227 commented Oct 21, 2019

I recently encountered the same problem. I guess PyTorch or the backend library is implemented this way because of synchronization issues. My workaround is to set the gradients of all frozen parameters to zero right after calling <any_loss>.backward(). This is not necessarily correct for stateful optimizers (e.g. SGD with momentum, Adam, etc.), but the reduction error is no longer triggered. Hope this helps.
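
For reference, a sketch of that workaround (`second_net`, `ddp_model`, `x`, and `optimizer` are placeholders, not names from this thread): zero the gradients of the parameters you want to hold fixed between backward() and step(), so the reducer still sees every parameter but a stateless optimizer makes no update for them.

```python
# Hypothetical: `second_net` is whichever submodule you want to hold fixed this step.
frozen_params = list(second_net.parameters())

loss = ddp_model(x).sum()   # placeholder forward/loss
loss.backward()             # the gradient all-reduce still covers every parameter

for p in frozen_params:
    if p.grad is not None:
        p.grad.zero_()      # plain SGD now leaves these parameters unchanged

optimizer.step()
optimizer.zero_grad()
```

As noted above, optimizers with internal state (momentum, Adam) may still move the zeroed parameters slightly.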

@pietern
Contributor

pietern commented Oct 24, 2019

@xf3227 Did you try the fix that the error message suggests (find_unused_parameters=True)?

If you freeze a subset of parameters, there is currently no way for DDP to know if the same set is frozen across all processes. Therefore, the parameters that don't receive gradients will be made to contribute zeroes, and the reduction is executed as expected. If the parameters are frozen on all processes, the reduced gradient should be all zeroes on all processes (assuming you have called zero_grad yourself before starting the next iteration).
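
For completeness, the flag from the error message is passed when constructing the wrapper; a minimal sketch (`model` and `local_rank` are placeholders):

```python
from torch.nn.parallel import DistributedDataParallel as DDP

# With this flag, DDP traverses the autograd graph after each forward pass,
# detects parameters that did not participate, and does not wait for their
# gradients. This adds some per-iteration overhead.
ddp_model = DDP(
    model.cuda(),
    device_ids=[local_rank],
    find_unused_parameters=True,
)
```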

@xf3227

xf3227 commented Oct 24, 2019

@pietern Thank you for pointing that out. I think I misstated my idea. What I was trying to do was not to detach anything on the fly. After calling <any_loss>.backward() and before <any_optimizer>.step(), we have a chance to manually modify the gradients. Based on @fuzihaofzh's description, he wants to train two models alternately, so he could just replace the gradients of one model with zeros. <any_stateless_optimizer>.step() will run as usual, but no change will be made since the gradients are zero.

To my knowledge, gradient reduction (i.e. synchronization) happens during <any_loss>.backward(); anything that comes after is processed independently on each GPU, so I believe it is safe to manipulate the gradients there.

Please let me know if I have misunderstood anything. I would really appreciate it.

@pietern
Contributor

pietern commented Oct 24, 2019

It's safe, but it's better to not synchronize at all if you don't have to.

@yinghuang

I met the same issue, but I solved it. The reason is that in my model class I define an FPN module with 5 levels of output feature maps in the init function, but in the forward function I only use 4 of them. When I use all 5, the problem goes away. My tentative conclusion: you should use every output of each module in the forward function.

@lim142857

Has this issue been solved? Specifically, is it possible to partially freeze a DDP-wrapped model?

@Zod-L

Zod-L commented Mar 6, 2023

As of March 2023, this is still an unsolved problem.

@francescotaioli

I was having issues while fine-tuning some layers of a pre-trained model with DDP.
Specifically, I froze some layers (note: after wrapping my model with DDP) and updated others, but I was receiving this error.

As @pietern suggested, setting param.requires_grad = False before wrapping the model with DDP solved the issue. Thanks!
