
Commit 43fb39c

rohan-varma authored and facebook-github-bot committed
[DDP] Make uneven inputs work with comm. hook (#61020)
Summary:
Pull Request resolved: #61020

Makes uneven input support with the `join` context manager work with custom communication hooks, so that the two features compose cleanly. Added unit tests covering the allreduce and powerSGD hooks. Instead of calling `allreduce` directly, the join manager now calls into `_run_comm_hook`, which runs whatever hook is installed.

ghstack-source-id: 132950108

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D29480028

fbshipit-source-id: c91dc467a62c5f1e0ec702a2944ae3deb10f93f4
1 parent 94b7306 commit 43fb39c
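
At the API level, the effect of this commit is that any comm hook registered on a DDP model is also honored while ranks are joined under the uneven-inputs context manager. Below is a minimal sketch of that user-facing combination; the process group setup, toy model, and per-rank batch counts are illustrative assumptions, not part of this commit.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD


def train(rank, world_size):
    # Illustrative setup; any init method supported by your cluster works.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = DDP(nn.Linear(10, 10).to(rank), device_ids=[rank])

    # Register a custom communication hook. With this change, the join
    # context manager runs the same hook for already-joined ranks instead
    # of falling back to a plain allreduce.
    state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)
    model.register_comm_hook(state, powerSGD.powerSGD_hook)

    # Ranks process different numbers of batches; join() keeps the
    # scheduled collectives matched so no rank hangs.
    num_batches = 5 + rank
    with model.join():
        for _ in range(num_batches):
            out = model(torch.randn(20, 10, device=rank))
            out.sum().backward()
```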

File tree

3 files changed: +53 -9 lines

torch/csrc/distributed/c10d/init.cpp
torch/nn/parallel/distributed.py
torch/testing/_internal/distributed/distributed_test.py

torch/csrc/distributed/c10d/init.cpp

Lines changed: 10 additions & 1 deletion
@@ -399,6 +399,15 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
           "_delay_all_reduce",
           &::c10d::Reducer::delay_all_reduce,
           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "_run_comm_hook",
+          [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket)
+              -> std::shared_ptr<jit::PythonFutureWrapper> {
+            c10::intrusive_ptr<c10::ivalue::Future> fut =
+                reducer.run_comm_hook(bucket);
+            return std::make_shared<jit::PythonFutureWrapper>(fut);
+          },
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "set_logger",
           [](::c10d::Reducer& reducer,
@@ -1472,7 +1481,7 @@ Example::
       .def(
           "get_future",
           [](::c10d::ProcessGroup::Work& work)
-             -> std::shared_ptr<jit::PythonFutureWrapper> {
+              -> std::shared_ptr<jit::PythonFutureWrapper> {
             return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());
           },
           R"(

torch/nn/parallel/distributed.py

Lines changed: 6 additions & 7 deletions
@@ -1014,11 +1014,11 @@ def _sync_final_model(self, is_last_joiner):
         )
         self._sync_params_and_buffers(authoritative_rank=self._authoritative_rank)
 
-    # Schedule allreduce ops to match those scheduled in the reducer's backward
+    # Schedule comm ops to match those scheduled in the reducer's backward
     # pass.
     def _match_all_reduce_for_bwd_pass(self):
-        allreduce_work = []
-        # Schedule allreduce in the same order as Reducer schedules them, i.e.
+        comm_work = []
+        # Schedule comm in the same order as Reducer schedules them, i.e.
         # the order of the buckets. Retrieving the bucket order from the reducer
         # ensures that we keep the same order in join mode, such as when bucket
         # order is rebuilt dynamically.
@@ -1031,10 +1031,9 @@ def _match_all_reduce_for_bwd_pass(self):
             # divide_by_initial_world_size=True, we divide grads by the static
             # world size, if not, the dividing factor is reduced by the number
             # of joined processes.
-            zero_tensor = grad_bucket.get_tensor()
-            work = self.process_group.allreduce(zero_tensor)
-            allreduce_work.append(work)
-        for work in allreduce_work:
+            work = self.reducer._run_comm_hook(grad_bucket)
+            comm_work.append(work)
+        for work in comm_work:
             work.wait()
 
     # Allreduces the used parameter mapping across ranks.
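
With `_run_comm_hook` in place, whatever callable was registered via `register_comm_hook` is what joined ranks execute on their zero-filled buckets, keeping their collectives in lockstep with the active ranks' backward pass. A hedged sketch of such a hook follows, only to show the `hook(state, bucket) -> Future` shape the join manager now relies on; the call counting and the delegation to the stock allreduce hook are illustrative, not part of this commit.

```python
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default


def counting_allreduce_hook(state, bucket):
    """Toy comm hook: count invocations, then defer to the built-in allreduce.

    `state` is whatever object was passed to register_comm_hook (assumed here
    to be a mutable dict). `bucket` is the GradBucket handed in by the reducer,
    or, for a joined rank, its zero-filled shadow bucket. The returned future
    is what both the backward pass and the join manager wait on.
    """
    state["calls"] = state.get("calls", 0) + 1
    # None -> use the default (WORLD) process group inside the stock hook.
    return default.allreduce_hook(None, bucket)


# Usage, assuming `model` is an initialized DistributedDataParallel instance:
# model.register_comm_hook({"calls": 0}, counting_allreduce_hook)
```

The next file's tests take the same approach: they plug the built-in allreduce and powerSGD hooks into the existing uneven-inputs test harness rather than defining new model fixtures.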

torch/testing/_internal/distributed/distributed_test.py

Lines changed: 37 additions & 1 deletion
@@ -11,7 +11,7 @@
 from contextlib import contextmanager, suppress
 from datetime import timedelta
 from functools import reduce
-from typing import Union, NamedTuple
+from typing import Union, NamedTuple, Callable, Any
 
 import torch
 import torch.cuda
@@ -183,6 +183,8 @@ class DDPUnevenTestInput(NamedTuple):
     inp: Union[torch.tensor, tuple]
     sync_interval: int
     throw_on_early_termination: bool = False
+    hook: Callable = None
+    state: Any = None
 
 
 class _FC2(nn.Module):
@@ -5384,6 +5386,11 @@ def _run_uneven_inputs_test(
             bucket_cap_mb=1,
             find_unused_parameters=find_unused_params,
         )
+        # Register hook if specified
+        if test_case.hook is not None:
+            net.register_comm_hook(test_case.state, test_case.hook)
+            print(f"registered hook {test_case.hook}")
+
 
         # Determine num iters for this rank via the passed in mapping.
         num_iters = iteration_mapping[rank]
@@ -5602,6 +5609,35 @@ def forward(self, x, rank):
             ),
         ]
 
+        # Test models that have a hook installed.
+        models_with_hook = [
+            DDPUnevenTestInput(
+                name="small_model_allreduce_hook",
+                model=small_model,
+                hook=default.allreduce_hook,
+                state=None,
+                inp=torch.ones(batch, dim, device=rank),
+                sync_interval=1,
+            ),
+            DDPUnevenTestInput(
+                name="small_model_power_sgd_hook",
+                model=small_model,
+                hook=powerSGD.powerSGD_hook,
+                state=powerSGD.PowerSGDState(
+                    process_group=None,
+                    matrix_approximation_rank=1,
+                    # Config so that powerSGD runs immediately instead of
+                    # allreduce.
+                    start_powerSGD_iter=1,
+                    warm_start=False,
+                    use_error_feedback=False,
+                ),
+                inp=torch.ones(batch, dim, device=rank),
+                sync_interval=1,
+            ),
+        ]
+        models_to_test.extend(models_with_hook)
+
         # Add resnet model if we have torchvision installed.
         if HAS_TORCHVISION:
             resnet_model = torchvision.models.resnet50()
