Add an option to DDP to take a list of parameters to ignore upfront. (#44826)

Summary:
Pull Request resolved: #44826

As described in #43690, some use cases need DDP to ignore certain parameters in the module (i.e., not install allreduce hooks on them). `find_unused_parameters` is sufficient from a correctness perspective, but an upfront list gives better performance when users already know which params are unused, since we won't have to traverse the autograd graph every iteration.

To enable this, we add a `parameters_to_ignore` field during DDP init and skip passing any parameter in that list to the reducer.
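
For illustration, a minimal usage sketch of the workaround API added in this diff (the model, `rank`, and process-group setup below are hypothetical placeholders, not part of this change):

    import torch
    import torch.nn as nn

    # Assume torch.distributed.init_process_group() has already been called
    # and `rank` is this process's local rank.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).to(rank)
    # Fully qualified names, as produced by model.named_parameters()/named_buffers().
    params_to_ignore = ["1.weight", "1.bias"]
    torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
        model, params_to_ignore
    )
    # The reducer will not install allreduce hooks for the ignored parameters.
    ddp = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
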
ghstack-source-id: 113210109

Test Plan: Added unittest

Reviewed By: xw285cornell, mrshenli

Differential Revision: D23740639

fbshipit-source-id: a0411712a8b0b809b9c9e6da04bef2b955ba5314
rohan-varma authored and facebook-github-bot committed Sep 30, 2020
1 parent c112e89 commit 181afd5
Showing 3 changed files with 156 additions and 12 deletions.
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/reducer.cpp
@@ -1341,8 +1341,9 @@ bool Reducer::rebuild_buckets() {
       replicas_[0].size() == rebuilt_param_indices_.size(),
       c10::str(
           "rebuilt parameter indices size is not same as original model parameters size.",
+          "Original model param size is: ",
           replicas_[0].size(),
-          " versus ",
+          " versus rebuilt params size of: ",
           rebuilt_param_indices_.size()));
   std::vector<std::vector<size_t>> rebuilt_bucket_indices;
   std::vector<size_t> bucket_size_limits;
77 changes: 66 additions & 11 deletions torch/nn/parallel/distributed.py
@@ -389,6 +389,10 @@ def __init__(self, module, device_ids=None,
         self.require_forward_param_sync = True
         self.ddp_join_enabled = False
         self.gradient_as_bucket_view = gradient_as_bucket_view
+        if hasattr(module, '_ddp_params_and_buffers_to_ignore'):
+            self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore
+        else:
+            self.parameters_to_ignore = []
 
         if check_reduction:
             # This argument is no longer used since the reducer
@@ -412,7 +416,11 @@ def __init__(self, module, device_ids=None,
         self._ddp_init_helper()
 
     def _sync_params_and_buffers(self, authoritative_rank=0):
-        module_states = list(self.module.state_dict().values())
+        module_states = []
+        for name, param in self.module.state_dict().items():
+            if name not in self.parameters_to_ignore:
+                module_states.append(param)
+
         if len(module_states) > 0:
             self._distributed_broadcast_coalesced(
                 module_states,
@@ -478,17 +486,55 @@ def model_parameters(m):
             self._module_copies = [self.module]
 
         self.modules_params = [list(parameters(m)) for m in self._module_copies]
-        self.modules_buffers = [list(m.buffers()) for m in self._module_copies]
+        # Collect buffers for modules, filtering out buffers that should be ignored.
+        named_module_buffers = [
+            [(buffer, buffer_name) for buffer_name, buffer in m.named_buffers()]
+            for m in self._module_copies
+        ]
+        self.modules_buffers = [
+            [
+                buffer
+                for (buffer, buffer_name) in module_buffers
+                if buffer_name not in self.parameters_to_ignore
+            ]
+            for module_buffers in named_module_buffers
+        ]
+        # Build tuple of (module, parameter) for all parameters that require grads.
+        if self.device_ids and len(self.device_ids) > 1:
+            # Single-process multi-device mode does not support self.parameters_to_ignore.
+            if self.parameters_to_ignore:
+                raise ValueError(
+                    "Single-Process multi-device mode does not "
+                    "support ignoring parameters upfront. Please consider "
+                    "using one DDP instance per device."
+                )
 
-        # Build tuple of (module, parameter) for all parameters that require grads.
-        modules_and_parameters = [
-            [
-                (module, parameter)
-                for module in replica.modules()
-                for parameter in filter(
-                    lambda parameter: parameter.requires_grad,
-                    parameters(module, recurse=False))
-            ] for replica in self._module_copies]
+            modules_and_parameters = [
+                [
+                    (module, parameter)
+                    for module in replica.modules()
+                    for parameter in filter(
+                        lambda parameter: parameter.requires_grad,
+                        parameters(module, recurse=False))
+                ] for replica in self._module_copies]
+        else:
+            modules_and_parameters = [
+                [
+                    (module, parameter)
+                    for module_name, module in replica.named_modules()
+                    for parameter in [
+                        param
+                        # Note that we access module.named_parameters instead of
+                        # parameters(module). parameters(module) is only needed in the
+                        # single-process multi device case, where it accesses replicated
+                        # parameters through _former_parameters.
+                        for param_name, param in module.named_parameters(recurse=False)
+                        if param.requires_grad
+                        and f"{module_name}.{param_name}" not in self.parameters_to_ignore
+                    ]
+                ]
+                for replica in self._module_copies
+            ]
 
         # Build list of parameters.
         parameters = [
@@ -1088,3 +1134,12 @@ def _check_comm_hook(self, hook):
             raise ValueError(
                 "Communication hook: return annotation should be torch.futures.Future or torch._C.Future."
             )
+
+    @staticmethod
+    def _set_params_and_buffers_to_ignore_for_model(
+        module, params_and_buffers_to_ignore
+    ):
+        # This is a workaround to set parameters and buffers DDP should ignore
+        # during synchronization. It will be removed when the API is finalized
+        # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
+        module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
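
The ignore list passed to this helper uses fully qualified names (the "<module>.<param>" form produced by `named_parameters()`/`named_buffers()`), as the test below demonstrates. A small sketch, with a hypothetical helper name, of building such names for an entire submodule:

    def fq_names_for_submodule(model, submodule_name):
        # Collect "<submodule>.<name>" strings matching model.named_parameters()/named_buffers().
        sub = dict(model.named_modules())[submodule_name]
        return [f"{submodule_name}.{n}" for n, _ in sub.named_parameters()] + [
            f"{submodule_name}.{n}" for n, _ in sub.named_buffers()
        ]

    # e.g. for the test's model: fq_names_for_submodule(model, "fc2")
    # -> ["fc2.weight", "fc2.ignore_buffer"]
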
88 changes: 88 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -3335,3 +3335,91 @@ def test_broadcast_object_list(self):
         self.assertNotEqual(objects, collectives_object_test_list)
         dist.broadcast_object_list(objects, src=0)
         self.assertEqual(objects, collectives_object_test_list)
+
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_ignore_params_arg(self):
+        class TestModel(nn.Module):
+            def __init__(self, rank):
+                self.rank = rank
+                super(TestModel, self).__init__()
+                self.fc1 = nn.Linear(1, 1, bias=False)
+                # Proxy that will be materialized to another architecture later.
+                # (after wrapping model with DDP)
+                if self.rank == 0:
+                    self.fc2 = nn.Linear(1, 10, bias=False)
+                else:
+                    self.fc2 = nn.Linear(10, 10, bias=False)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = self.fc2(x)
+                return x
+
+        device_id = self.rank
+        # Ensure the test works for both find_unused_parameter and broadcast_buffer settings.
+        for (find_unused, broadcast_buffers) in itertools.product([False, True], [False, True]):
+            model = TestModel(self.rank).float().to(device_id)
+            # Note that the model can have different shape buffers if we pass
+            # them in to be ignored as well.
+            model.fc2.register_buffer(
+                "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank)
+            )
+            proxy_params = list(model.fc2.parameters())
+            proxy_buffers = list(model.fc2.buffers())
+            model_fc2_name = [
+                module_name
+                for module_name, module in model.named_modules()
+                if module is model.fc2
+            ][0]
+            proxy_param_names = [
+                f"{model_fc2_name}.{param_name}"
+                for param_name, _ in model.fc2.named_parameters()
+            ]
+            proxy_buffer_names = [
+                f"{model_fc2_name}.{buf_name}"
+                for buf_name, _ in model.fc2.named_buffers()
+            ]
+            # Specify that we should ignore proxy_params since it will be
+            # materialized later.
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model, proxy_param_names + proxy_buffer_names
+            )
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[device_id],
+                find_unused_parameters=find_unused,
+                broadcast_buffers=broadcast_buffers,
+            )
+            # Materialize new params. These are not registered in DDP and thus
+            # don't have autograd hooks installed on them.
+            ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id)
+            # local model with the new materialized parameters.
+            local_model = copy.deepcopy(ddp.module).cuda(self.rank)
+
+            inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1)
+            for i in range(6):
+                ddp(inp).sum().backward()
+                local_model(inp).sum().backward()
+                # materialized param grad is not touched by DDP, so its grad should
+                # be the same as if running locally.
+                for materialized_param, local_param in zip(
+                    ddp.module.fc2.parameters(), local_model.fc2.parameters()
+                ):
+                    self.assertEqual(materialized_param.grad, local_param.grad)
+
+                # fc1 parameter grad should still be different, due to allreduce.
+                for synced_param, local_param in zip(
+                    ddp.module.fc1.parameters(), local_model.fc1.parameters()
+                ):
+                    self.assertFalse(synced_param.grad == local_param.grad)
+
+                # Proxy module grad should not be touched
+                for proxy_param in proxy_params:
+                    self.assertTrue(proxy_param.grad is None)
+
+            # Synchronize since we run multiple iterations of this test, to
+            # isolate failure hangs.
+            torch.cuda.synchronize(device=self.rank)
