Forward/backward hooks support in DistributedDataParallel #35191

Closed
h6197627 opened this issue Mar 22, 2020 · 5 comments
Labels
feature A request for a proper, new feature. module: data parallel triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@h6197627
Contributor

h6197627 commented Mar 22, 2020

🚀 Feature

PyTorch now recommends using DistributedDataParallel over DataParallel for all kinds of multi-GPU training (#35063). However, it has one limitation compared to the old DataParallel module: it currently cannot handle forward/backward hooks in a user-convenient way.
Proposed workaround

.. warning::
Forward and backward hooks defined on :attr:`module` and its submodules
won't be invoked anymore, unless the hooks are initialized in the
:meth:`forward` method.

This workaround requires users to edit each model's forward-propagation code in order to use hooks with a model wrapped in DDP.
As I understand it, this limitation was not part of the original design and was discovered while fixing another issue (#5061). So I am wondering whether it would be possible to implement some sort of hook synchronization mechanism across distributed model replicas.
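As a rough illustration of what the documented workaround looks like in practice, here is a minimal sketch (the Net module, layer, and hook body are made up for the example, not taken from real code):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 8)

    def forward(self, x):
        # Register the hook inside forward() so it also exists on any replica
        # DDP creates for this pass; remove the handle right away so repeated
        # calls do not accumulate hooks.
        handle = self.linear.register_forward_hook(
            lambda module, inp, out: print('linear output shape:', tuple(out.shape)))
        out = self.linear(x)
        handle.remove()
        return out

if __name__ == '__main__':
    Net()(torch.randn(2, 16))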

Motivation

Also, with the current workaround, the ability to use hooks dynamically is lost for the DistributedDataParallel module. For example, in my current code with DataParallel I can place and remove hooks dynamically: during the validation phase of training I place hooks to extract additional bottleneck features and compute some complementary evaluation metrics that are not calculated during the training phase (sketched below).
In general, the current hooking mechanism does not look fully compatible with DDP.
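A rough sketch of this dynamic-hook pattern (the toy model, layer choice, and hook body are hypothetical, just to show the place-and-remove flow):

import torch
import torch.nn as nn

# Toy model standing in for the real network; treat the first linear layer as
# the "bottleneck" whose activations we want during validation.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

captured = {}

def save_bottleneck(module, inp, out):
    # Stash the intermediate activation so complementary validation metrics
    # can be computed from it later.
    captured['bottleneck'] = out.detach()

# Validation phase: place the hook, run the model, then remove the hook so it
# does not fire during the training phase.
handle = model[0].register_forward_hook(save_bottleneck)
with torch.no_grad():
    model(torch.randn(2, 16))
handle.remove()

print(captured['bottleneck'].shape)  # torch.Size([2, 8])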

Pitch

A hooking mechanism for the DistributedDataParallel module that, from the user's perspective, works the same way as in the DataParallel module.

@pbelevich pbelevich added feature A request for a proper, new feature. module: data parallel triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module triage review and removed triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Mar 23, 2020
@ailzhang ailzhang added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 30, 2020
@jbojar

jbojar commented Sep 29, 2020

I think this warning about forward and backward hooks in DistributedDataParallel applies only when the model is replicated by the DistributedDataParallel.replicate method, which probably only happens when using multiple GPUs in a single process.

In single-GPU-per-process mode the whole model is created from scratch in every process, and forward and backward hooks work just fine (we haven't experienced any problems with hooks in such a setup).

@h6197627
Contributor Author

h6197627 commented Nov 3, 2020

Hi @jbojar,
I need to test this more thoroughly. I remember that back when DDP was introduced as the preferred alternative to DataParallel (quite a long time ago) I was able to reproduce this limiting behavior, but after your comment I tried a simple script and it looks like it works:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

class Model(nn.Module):
    def __init__(self, input_dim=(32, 32, 3)):
        super(Model, self).__init__()
        assert (len(input_dim) == 3)
        conv_channels = 16
        self.conv = nn.Conv2d(in_channels=input_dim[-1], out_channels=conv_channels, kernel_size=3, stride=1, padding=1)
        self.linear = nn.Linear(in_features=input_dim[0]*input_dim[1]*conv_channels, out_features=128)
        nn.init.kaiming_uniform_(self.conv.weight, a=1)
        nn.init.constant_(self.conv.bias, 0)
        nn.init.kaiming_uniform_(self.linear.weight, a=1)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, input):
        print('Forward: {}'.format(torch.distributed.get_rank()))
        out = self.conv(input)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def test_ddp_fn(rank, world_size, hook_before):
    setup(rank, world_size)

    model = Model().to(rank)
    # Dummy hook function imitating ReLU activation function
    def relu_act(module, input):
        print('Hook: {}'.format(torch.distributed.get_rank()))
        return F.relu(input[0])

    if hook_before:
        model.linear.register_forward_pre_hook(lambda module, input: relu_act(module, input))
        print('Hook registered before wrapping with DDP: {}'.format(torch.distributed.get_rank()))
    model = DDP(model, device_ids=[rank])
    if not hook_before:
        model.module.linear.register_forward_pre_hook(lambda module, input: relu_act(module, input))
        print('Hook registered after wrapping with DDP: {}'.format(torch.distributed.get_rank()))

    input_data = torch.rand(1, 3, 32, 32)
    model(input_data)

    cleanup()


def setup(rank, world_size):
    # Initialize the default process group (one process per rank)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend='gloo', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    hook_before = True
    mp.spawn(test_ddp_fn, args=(world_size, hook_before), nprocs=world_size, join=True)

Prints

Hook registered before wrapping with DDP: 1
Hook registered before wrapping with DDP: 0
Forward: 0
Forward: 1
Hook: 1
Hook: 0

@jbojar

jbojar commented Nov 4, 2020 via email

@aluo-x

aluo-x commented Aug 25, 2021

I'm actually quite confused. The language is still present in the documentation:

Forward and backward hooks defined on module and its submodules won’t be invoked anymore, unless the hooks are initialized in the forward() method.

So what is the correct way to register forward/backward hooks when using DDP?

@h6197627
Contributor Author

Closed in #74063
