Revert "[DDP] multiple forward support for static graph (#103487)" (#…
Browse files Browse the repository at this point in the history
…103873)

Per the discussion in #103629 (comment), I'm preemptively creating this revert PR to revert all commits in the stack. This seems like a safer option than using the bot, as the commit has already been in trunk since last week.
Pull Request resolved: #103873
Approved by: https://github.com/rohan-varma
huydhn authored and pytorchmergebot committed Jun 20, 2023
1 parent 7b6dc72 commit b1ddd5a
Showing 5 changed files with 18 additions and 88 deletions.
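For context, the reverted stack (#103487) targeted running several forward passes through a static-graph DDP module before a single backward call, the pattern exercised by the test removed below. A minimal sketch of that usage pattern (not part of this commit), assuming a process group already initialized by the launcher and one GPU per rank:

import torch
import torch.distributed as dist
import torch.nn as nn

# Assumes dist.init_process_group(...) was already called by the launcher.
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = nn.Linear(10, 10).cuda(rank)
ddp_model = nn.parallel.DistributedDataParallel(
    model, device_ids=[rank], static_graph=True
)

inp = torch.ones(2, 10, device="cuda")
ddp_model.zero_grad()

# Two forward passes feed a single backward pass. Supporting this under
# static_graph=True is what the reverted commits added; this revert restores
# the earlier num_iterations-based first-iteration bookkeeping in the reducer.
a = ddp_model(inp)
b = ddp_model(inp)
(a.sum() + b.sum()).backward()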
2 changes: 0 additions & 2 deletions test/distributed/test_distributed_spawn.py
@@ -34,8 +34,6 @@ def setUp(self):
super().setUp()
self._spawn_processes()
torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()
else:
print(f"Invalid backend {BACKEND}. Tests will not be run!")


if __name__ == "__main__":
12 changes: 2 additions & 10 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -109,8 +109,6 @@ Reducer::Reducer(
gradient_as_bucket_view_(gradient_as_bucket_view),
local_used_map_reduced_(false),
num_iterations_(0),
num_bwd_calls_(0),
first_autograd_hook_called_(false),
num_buckets_ready_(0),
has_rebuilt_bucket_(false),
bucket_bytes_cap_(bucket_bytes_cap),
@@ -269,11 +267,11 @@ bool Reducer::dynamic_graph_find_unused() {
}

bool Reducer::static_graph_first_iteration() {
return static_graph_ && num_bwd_calls_ == 1;
return static_graph_ && num_iterations_ == 1;
}

bool Reducer::static_graph_after_first_iteration() {
return static_graph_ && num_bwd_calls_ > 1;
return static_graph_ && num_iterations_ > 1;
}

bool Reducer::ddp_graph_static() {
@@ -615,10 +613,6 @@ void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger) {
// This function is only to be called from the autograd thread.
void Reducer::autograd_hook(size_t index) {
std::lock_guard<std::mutex> lock(this->mutex_);
if (!first_autograd_hook_called_) {
first_autograd_hook_called_ = true;
num_bwd_calls_++;
}
// Ignore if we don't expect to be called.
// This may be the case if the user wants to accumulate gradients
// for number of iterations before reducing them.
@@ -1526,8 +1520,6 @@ void Reducer::finalize_backward() {
// No longer expect autograd hooks to fire after this function returns.
TORCH_INTERNAL_ASSERT(expect_autograd_hooks_);
expect_autograd_hooks_ = false;
// reset for the next iteration
first_autograd_hook_called_ = false;

// No longer require call to finalize after this function returns.
TORCH_INTERNAL_ASSERT(require_finalize_);
5 changes: 0 additions & 5 deletions torch/csrc/distributed/c10d/reducer.hpp
@@ -404,11 +404,6 @@ class TORCH_API Reducer {

// track the number of iterations to synchronize grads in training so far.
long num_iterations_;
// track distinct iteration of backward call. This is distinct from num_iterations_,
// for example in the case of multiple forward before backward.
long num_bwd_calls_;
// whether the first autograd hook for a distinct backward pass has been called.
bool first_autograd_hook_called_;
// track the number of buckets that have been ready for
// communication calls like allReduce or communication hooks.
int num_buckets_ready_;
31 changes: 16 additions & 15 deletions torch/nn/parallel/distributed.py
@@ -239,11 +239,12 @@ class _BufferCommHook:
# is completed.
class _DDPSink(Function):
@staticmethod
def forward(ctx, ddp_weakref, *inputs):
def forward(ctx, reducer, state_dict, *inputs):
# set_materialize_grads(False) will ensure that None gradients stay as
# None and are not filled with zeros.
ctx.set_materialize_grads(False)
ctx.ddp_weakref = ddp_weakref
ctx.reducer = reducer
ctx.state_dict = state_dict
ret = tuple(
inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
)
@@ -253,19 +254,12 @@ def forward(ctx, ddp_weakref, *inputs):
def backward(ctx, *grad_outputs):
# Enqueue delay allreduce for static graph training on the first
# iteration.
ddp_weakref = ctx.ddp_weakref()
reducer = ddp_weakref.reducer
static_graph = ddp_weakref.static_graph
delay_ar_enqueued = (
static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
)
if static_graph and not delay_ar_enqueued:
if ctx.state_dict["static_graph"] and ctx.state_dict["num_iterations"] == 1:
Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
reducer._delay_all_reduce
ctx.reducer._delay_all_reduce
)
ddp_weakref._static_graph_delay_allreduce_enqueued = True

return (None, *grad_outputs)
return (None, None, *grad_outputs)


class _DDPJoinHook(JoinHook):
@@ -1053,6 +1047,7 @@ def _ddp_init_helper(
(4) Logging construction-time DDP logging data
(5) passing a handle of DDP to SyncBatchNorm Layer
"""
self.num_iterations = 0
# Notice, the parameters order is not in the order in which they are used,
# especially in models with control flow.
#
@@ -1386,6 +1381,7 @@ def _pre_forward(self, *inputs, **kwargs):
if torch.is_grad_enabled() and self.require_backward_grad_sync:
assert self.logger is not None
self.logger.set_runtime_stats_and_log()
self.num_iterations += 1
self.reducer.prepare_for_forward()

# Notify the join context that this process has not joined, if
@@ -1470,8 +1466,13 @@ def _post_forward(self, output):
# TODO: DDPSink is currently enabled for unused parameter detection and
# static graph training for first iteration.
if (self.find_unused_parameters and not self.static_graph) or (
self.static_graph and not self._static_graph_delay_allreduce_enqueued
self.static_graph and self.num_iterations == 1
):
state_dict = {
"static_graph": self.static_graph,
"num_iterations": self.num_iterations,
}

(
output_tensor_list,
treespec,
@@ -1490,7 +1491,8 @@ def _post_forward(self, output):
# undefined gradient which the reducer then handles to ensure
# param.grad field is not touched and we don't error out.
passthrough_tensor_list = _DDPSink.apply(
weakref.ref(self),
self.reducer,
state_dict,
*output_tensor_list,
)
for i in range(len(output_placeholders)):
@@ -2202,7 +2204,6 @@ def _set_static_graph(self):
)
return
self.static_graph = True
self._static_graph_delay_allreduce_enqueued = False
self.reducer._set_static_graph()
assert self.logger is not None
self.logger._set_static_graph()
56 changes: 0 additions & 56 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -9810,62 +9810,6 @@ def forward(self, x):
for buf in bufs[1:]:
self.assertEqual(rank_0_buf, buf)

@skip_if_lt_x_gpu(2)
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl" and BACKEND != "gloo",
"Only Nccl & Gloo backend support DistributedDataParallel",
)
def test_static_graph_multi_forward(self):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.lin = nn.Linear(10, 10)
self.relu = nn.ReLU()

def forward(self, x):
return self.relu(self.lin(x))

torch.cuda.set_device(self.rank)
torch.manual_seed(42 << 1337 % (self.rank + 1))
model = Net().cuda(self.rank)
local_model = copy.deepcopy(model)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[self.rank], static_graph=True
)
inp = torch.ones(2, 10, device="cuda")
for _ in range(3):
model.zero_grad()
local_model.zero_grad()
a = model(inp)
b = model(inp)
loss = a.sum() + b.sum()
loss.backward()
# Grads should be equal to a local model that ran through inp twice and averaged grads
if self.rank == 0:
inp_clone = inp.clone()
for _ in range(2):
a = local_model(inp_clone)
b = local_model(inp_clone)
loss = a.sum() + b.sum()
loss.backward()

ws = dist.get_world_size()
for p in local_model.parameters():
p.grad.data = p.grad / dist.get_world_size()

for p_ddp, p_local in zip(
model.parameters(),
local_model.parameters()
):
self.assertTrue(
torch.allclose(
p_ddp.grad, p_local.grad
),
f"{p_ddp.grad} vs {p_local.grad}"
)

dist.barrier()

@skip_if_lt_x_gpu(2)
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl" and BACKEND != "gloo",
