[DDP] multiple forward support for static graph #103487

Closed
wants to merge 5 commits into from
2 changes: 2 additions & 0 deletions test/distributed/test_distributed_spawn.py
@@ -34,6 +34,8 @@ def setUp(self):
super().setUp()
self._spawn_processes()
torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()
else:
print(f"Invalid backend {BACKEND}. Tests will not be run!")


if __name__ == "__main__":
12 changes: 10 additions & 2 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -109,6 +109,8 @@ Reducer::Reducer(
gradient_as_bucket_view_(gradient_as_bucket_view),
local_used_map_reduced_(false),
num_iterations_(0),
num_bwd_calls_(0),
first_autograd_hook_called_(false),
num_buckets_ready_(0),
has_rebuilt_bucket_(false),
bucket_bytes_cap_(bucket_bytes_cap),
@@ -267,11 +269,11 @@ bool Reducer::dynamic_graph_find_unused() {
}

bool Reducer::static_graph_first_iteration() {
-return static_graph_ && num_iterations_ == 1;
+return static_graph_ && num_bwd_calls_ == 1;
}

bool Reducer::static_graph_after_first_iteration() {
-return static_graph_ && num_iterations_ > 1;
+return static_graph_ && num_bwd_calls_ > 1;
}

bool Reducer::ddp_graph_static() {
@@ -613,6 +615,10 @@ void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger) {
// This function is only to be called from the autograd thread.
void Reducer::autograd_hook(size_t index) {
std::lock_guard<std::mutex> lock(this->mutex_);
if (!first_autograd_hook_called_) {
first_autograd_hook_called_ = true;
num_bwd_calls_++;
}
// Ignore if we don't expect to be called.
// This may be the case if the user wants to accumulate gradients
// for number of iterations before reducing them.
@@ -1520,6 +1526,8 @@ void Reducer::finalize_backward() {
// No longer expect autograd hooks to fire after this function returns.
TORCH_INTERNAL_ASSERT(expect_autograd_hooks_);
expect_autograd_hooks_ = false;
// Reset for the next backward pass.
first_autograd_hook_called_ = false;

// No longer require call to finalize after this function returns.
TORCH_INTERNAL_ASSERT(require_finalize_);
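The autograd_hook() change above is the heart of the fix: the hook fires once per parameter gradient, so a per-backward flag is needed to bump num_bwd_calls_ exactly once, and finalize_backward() clears the flag for the next pass. A minimal Python sketch of that gating, for illustration only (the class and method names mirror the C++ members and are not the real Reducer API):

class ReducerCounterSketch:
    """Toy model of how first_autograd_hook_called_ gates num_bwd_calls_."""

    def __init__(self) -> None:
        self.num_bwd_calls = 0
        self.first_autograd_hook_called = False

    def autograd_hook(self, index: int) -> None:
        # Fires once per parameter gradient during a backward pass;
        # only the first call of the pass increments the backward counter.
        if not self.first_autograd_hook_called:
            self.first_autograd_hook_called = True
            self.num_bwd_calls += 1

    def finalize_backward(self) -> None:
        # Reset so the next backward pass is counted as a new one.
        self.first_autograd_hook_called = False

sketch = ReducerCounterSketch()
for idx in range(3):  # three parameter hooks firing in a single backward
    sketch.autograd_hook(idx)
sketch.finalize_backward()
assert sketch.num_bwd_calls == 1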
5 changes: 5 additions & 0 deletions torch/csrc/distributed/c10d/reducer.hpp
@@ -404,6 +404,11 @@ class TORCH_API Reducer {

// track the number of iterations to synchronize grads in training so far.
long num_iterations_;
// Track the number of distinct backward calls. This can differ from num_iterations_,
// e.g., when multiple forwards run before a single backward.
long num_bwd_calls_;
// whether the first autograd hook for a distinct backward pass has been called.
bool first_autograd_hook_called_;
// track the number of buckets that have been ready for
// communication calls like allReduce or communication hooks.
int num_buckets_ready_;
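To make the new comments concrete: num_iterations_ is bumped once per forward pass (in prepare_for_forward()), so it diverges from the backward count as soon as a step runs several forwards. An illustrative trace under that assumption (values are for exposition, not taken from a real run):

# Two forwards, then one backward, with static_graph=True:
#   forward #1 -> num_iterations_ = 1, num_bwd_calls_ = 0
#   forward #2 -> num_iterations_ = 2, num_bwd_calls_ = 0
#   backward   -> num_iterations_ = 2, num_bwd_calls_ = 1
# A first-iteration check on num_iterations_ == 1 is already false here,
# while num_bwd_calls_ == 1 still identifies the first backward pass.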
12 changes: 7 additions & 5 deletions torch/nn/parallel/distributed.py
@@ -256,11 +256,14 @@ def backward(ctx, *grad_outputs):
ddp_weakref = ctx.ddp_weakref()
reducer = ddp_weakref.reducer
static_graph = ddp_weakref.static_graph
-num_forward_calls = ddp_weakref.num_forward_calls
-if static_graph and num_forward_calls == 1:
+delay_ar_enqueued = (
+static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
+)
+if static_graph and not delay_ar_enqueued:
Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
reducer._delay_all_reduce
)
+ddp_weakref._static_graph_delay_allreduce_enqueued = True

return (None, *grad_outputs)

@@ -1050,7 +1053,6 @@ def _ddp_init_helper(
(4) Logging construction-time DDP logging data
(5) passing a handle of DDP to SyncBatchNorm Layer
"""
-self.num_forward_calls = 0
# Notice, the parameters order is not in the order in which they are used,
# especially in models with control flow.
#
@@ -1384,7 +1386,6 @@ def _pre_forward(self, *inputs, **kwargs):
if torch.is_grad_enabled() and self.require_backward_grad_sync:
assert self.logger is not None
self.logger.set_runtime_stats_and_log()
-self.num_forward_calls += 1
self.reducer.prepare_for_forward()

# Notify the join context that this process has not joined, if
@@ -1469,7 +1470,7 @@ def _post_forward(self, output):
# TODO: DDPSink is currently enabled for unused parameter detection and
# static graph training for first iteration.
if (self.find_unused_parameters and not self.static_graph) or (
-self.static_graph and self.num_forward_calls == 1
+self.static_graph and not self._static_graph_delay_allreduce_enqueued
):
(
output_tensor_list,
@@ -2201,6 +2202,7 @@ def _set_static_graph(self):
)
return
self.static_graph = True
self._static_graph_delay_allreduce_enqueued = False
self.reducer._set_static_graph()
assert self.logger is not None
self.logger._set_static_graph()
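For context, a minimal sketch of the user-facing pattern this PR targets: a static-graph DDP module that runs more than one forward before a single backward. The training-step function below is hypothetical (the module, optimizer, and process-group setup are placeholders), but static_graph=True plus the multiple-forward/single-backward shape is exactly what the _static_graph_delay_allreduce_enqueued flag above guards:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def train_step(model: DDP, inp: torch.Tensor, opt: torch.optim.Optimizer) -> None:
    # Two forwards over the same static graph, then a single backward.
    opt.zero_grad()
    out_a = model(inp)
    out_b = model(inp)
    loss = out_a.sum() + out_b.sum()
    loss.backward()  # the delayed allreduce callback is enqueued only for the first backward
    opt.step()

# Construction sketch, assuming init_process_group() has already run and
# `rank` is this process's GPU index:
#   net = nn.Sequential(nn.Linear(10, 10), nn.ReLU()).cuda(rank)
#   ddp = DDP(net, device_ids=[rank], static_graph=True)
#   opt = torch.optim.SGD(ddp.parameters(), lr=0.1)
#   train_step(ddp, torch.ones(2, 10, device=f"cuda:{rank}"), opt)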
56 changes: 56 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -9769,6 +9769,62 @@ def forward(self, x):
for buf in bufs[1:]:
self.assertEqual(rank_0_buf, buf)

@skip_if_lt_x_gpu(2)
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl" and BACKEND != "gloo",
"Only Nccl & Gloo backend support DistributedDataParallel",
)
def test_static_graph_multi_forward(self):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.lin = nn.Linear(10, 10)
self.relu = nn.ReLU()

def forward(self, x):
return self.relu(self.lin(x))

torch.cuda.set_device(self.rank)
torch.manual_seed(42 << 1337 % (self.rank + 1))
model = Net().cuda(self.rank)
local_model = copy.deepcopy(model)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[self.rank], static_graph=True
)
inp = torch.ones(2, 10, device="cuda")
for _ in range(3):
model.zero_grad()
local_model.zero_grad()
a = model(inp)
b = model(inp)
loss = a.sum() + b.sum()
loss.backward()
# Grads should match those of a local model that ran through inp twice and averaged its grads.
if self.rank == 0:
inp_clone = inp.clone()
for _ in range(2):
a = local_model(inp_clone)
b = local_model(inp_clone)
loss = a.sum() + b.sum()
loss.backward()

ws = dist.get_world_size()
for p in local_model.parameters():
p.grad.data = p.grad / ws

for p_ddp, p_local in zip(
model.parameters(),
local_model.parameters()
):
self.assertTrue(
torch.allclose(
p_ddp.grad, p_local.grad
),
f"{p_ddp.grad} vs {p_local.grad}"
)

dist.barrier()

@skip_if_lt_x_gpu(2)
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl" and BACKEND != "gloo",