[DDP] Perform input casting in pre forward #100131

Closed · wants to merge 1 commit
63 changes: 32 additions & 31 deletions torch/nn/parallel/distributed.py
@@ -1369,34 +1369,8 @@ def _inside_ddp_forward(self):
             DistributedDataParallel._active_ddp_module = None
 
     def _run_ddp_forward(self, *inputs, **kwargs):
-        if self.device_ids:
-            inputs, kwargs = _to_kwargs(
-                inputs,
-                kwargs,
-                torch.device(self.device_type, self.device_ids[0]),
-                self.use_side_stream_for_tensor_copies,
-            )
-            args, kwargs = inputs[0], kwargs[0]  # type: ignore[index]
-            # Cast inputs to reduced precision if needed.
-            if self.mixed_precision is not None:
-                args, kwargs = _cast_forward_inputs(
-                    self.mixed_precision.param_dtype,
-                    *args,
-                    **kwargs,
-                )
-            with self._inside_ddp_forward():
-                return self.module(*args, **kwargs)  # type: ignore[index]
-        else:
-            # Cast inputs to reduced precision if needed.
-            # TODO (rohan-varma) test this codepath.
-            if self.mixed_precision is not None:
-                inputs, kwargs = _cast_forward_inputs(
-                    self.mixed_precision.param_dtype,
-                    *inputs,
-                    **kwargs,
-                )
-            with self._inside_ddp_forward():
-                return self.module(*inputs, **kwargs)
+        with self._inside_ddp_forward():
+            return self.module(*inputs, **kwargs)  # type: ignore[index]
 
     def _clear_grad_buffer(self):
         # Making param.grad points to the grad buffers before backward is based on the
@@ -1419,9 +1393,9 @@ def _clear_grad_buffer(self):
             if all_param_grad_none:
                 self._delay_grad_buffer.zero_()
 
-    def _pre_forward(self):
+    def _pre_forward(self, *inputs, **kwargs):
         if self._delay_all_reduce_all_params:
-            return
+            return inputs, kwargs
 
         if torch.is_grad_enabled() and self.require_backward_grad_sync:
             assert self.logger is not None
@@ -1456,6 +1430,33 @@ def _pre_forward(self):
             # Notify joined ranks whether they should sync in backwards pass or not.
             self._check_global_requires_backward_grad_sync(is_joined_rank=False)
 
+        if self.device_ids:
+            inputs, kwargs = _to_kwargs(
+                inputs,
+                kwargs,
+                torch.device(self.device_type, self.device_ids[0]),
+                self.use_side_stream_for_tensor_copies,
+            )
+            args, kwargs = inputs[0], kwargs[0]  # type: ignore[index]
+            # Cast inputs to reduced precision if needed.
+            if self.mixed_precision is not None:
+                args, kwargs = _cast_forward_inputs(
+                    self.mixed_precision.param_dtype,
+                    *args,
+                    **kwargs,
+                )
+            return args, kwargs
+        else:
+            # Cast inputs to reduced precision if needed.
+            # TODO (rohan-varma) test this codepath.
+            if self.mixed_precision is not None:
+                inputs, kwargs = _cast_forward_inputs(
+                    self.mixed_precision.param_dtype,
+                    *inputs,
+                    **kwargs,
+                )
+            return inputs, kwargs
+
     def _post_forward(self, output):
         if self._delay_all_reduce_all_params:
             self._clear_grad_buffer()
@@ -1528,7 +1529,7 @@ def _post_forward(self, output):
 
     def forward(self, *inputs, **kwargs):
         with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
-            self._pre_forward()
+            inputs, kwargs = self._pre_forward(*inputs, **kwargs)
             output = (
                 self.module.forward(*inputs, **kwargs)
                 if self._delay_all_reduce_all_params
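
The net effect of the diff above: device placement (`_to_kwargs`) and mixed-precision casting (`_cast_forward_inputs`) now run inside `_pre_forward`, which accepts and returns the forward inputs, while `_run_ddp_forward` is reduced to dispatching into the wrapped module. The standalone sketch below illustrates the same pattern outside of DDP: inputs are cast in a pre-forward step and the forward call itself stays untouched. The helper `cast_forward_inputs` and the `torch.bfloat16` target dtype are illustrative assumptions, not DDP's actual internals.

# Minimal, self-contained sketch of "cast inputs in a pre-forward step"; it is
# not DDP's implementation, just the same idea expressed with a forward pre-hook.
import torch
import torch.nn as nn


def cast_forward_inputs(dtype, *args, **kwargs):
    # Cast only floating-point tensors to the target dtype; pass everything else through.
    def cast(x):
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype)
        return x

    return tuple(cast(a) for a in args), {k: cast(v) for k, v in kwargs.items()}


def pre_forward(module, args, kwargs):
    # Analogous to the new DDP._pre_forward: return the (possibly casted) inputs
    # that the real forward then receives unchanged.
    return cast_forward_inputs(torch.bfloat16, *args, **kwargs)


model = nn.Linear(4, 4).to(torch.bfloat16)
model.register_forward_pre_hook(pre_forward, with_kwargs=True)

out = model(torch.randn(2, 4))  # float32 input is cast to bfloat16 before forward runs
print(out.dtype)  # torch.bfloat16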