[DDP] Perform input casting in pre forward (#100131)
This is so that replicate can also cast its inputs, which it currently does not.
The next diff will change the replicate pre-forward hook to support this.

Differential Revision: [D45335179](https://our.internmc.facebook.com/intern/diff/D45335179/)
Pull Request resolved: #100131
Approved by: https://github.com/zhaojuanmao
rohan-varma authored and pytorchmergebot committed Apr 27, 2023
1 parent ae0eb23 commit 87db02e
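
To make the motivation concrete: the device copy and mixed-precision cast of forward inputs now happen in a pre-forward step instead of inside _run_ddp_forward, so a wrapper that only installs a forward pre-hook (as replicate does) can reuse the same behavior. Below is a minimal, standalone sketch of that general pattern using the public register_forward_pre_hook(..., with_kwargs=True) API; it is not the DDP or replicate implementation, and the cast_floating_point_inputs helper is a hypothetical stand-in for the private _cast_forward_inputs used in the diff.

# Minimal, standalone sketch (not the DDP/replicate implementation).
# Uses only public PyTorch APIs; cast_floating_point_inputs is hypothetical.
import torch
import torch.nn as nn


def cast_floating_point_inputs(dtype, *args, **kwargs):
    # Recursively cast floating-point tensors to `dtype`; leave everything else as-is.
    def _cast(x):
        if isinstance(x, torch.Tensor) and x.is_floating_point():
            return x.to(dtype)
        if isinstance(x, (list, tuple)):
            return type(x)(_cast(v) for v in x)
        if isinstance(x, dict):
            return {k: _cast(v) for k, v in x.items()}
        return x

    return _cast(args), _cast(kwargs)


def install_input_cast_hook(module: nn.Module, dtype: torch.dtype):
    # with_kwargs=True lets the hook rewrite both positional and keyword inputs,
    # which is what a replicate-style pre-forward hook would need.
    def pre_forward(mod, args, kwargs):
        return cast_floating_point_inputs(dtype, *args, **kwargs)

    return module.register_forward_pre_hook(pre_forward, with_kwargs=True)


# Usage: fp32 inputs are cast to fp16 before the wrapped forward runs.
model = nn.Linear(4, 4).to(torch.float16)
install_input_cast_hook(model, torch.float16)
out = model(torch.randn(2, 4))
assert out.dtype == torch.float16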
Showing 1 changed file with 32 additions and 31 deletions: torch/nn/parallel/distributed.py
@@ -1369,34 +1369,8 @@ def _inside_ddp_forward(self):
             DistributedDataParallel._active_ddp_module = None
 
     def _run_ddp_forward(self, *inputs, **kwargs):
-        if self.device_ids:
-            inputs, kwargs = _to_kwargs(
-                inputs,
-                kwargs,
-                torch.device(self.device_type, self.device_ids[0]),
-                self.use_side_stream_for_tensor_copies,
-            )
-            args, kwargs = inputs[0], kwargs[0]  # type: ignore[index]
-            # Cast inputs to reduced precision if needed.
-            if self.mixed_precision is not None:
-                args, kwargs = _cast_forward_inputs(
-                    self.mixed_precision.param_dtype,
-                    *args,
-                    **kwargs,
-                )
-            with self._inside_ddp_forward():
-                return self.module(*args, **kwargs)  # type: ignore[index]
-        else:
-            # Cast inputs to reduced precision if needed.
-            # TODO (rohan-varma) test this codepath.
-            if self.mixed_precision is not None:
-                inputs, kwargs = _cast_forward_inputs(
-                    self.mixed_precision.param_dtype,
-                    *inputs,
-                    **kwargs,
-                )
-            with self._inside_ddp_forward():
-                return self.module(*inputs, **kwargs)
+        with self._inside_ddp_forward():
+            return self.module(*inputs, **kwargs)  # type: ignore[index]
 
     def _clear_grad_buffer(self):
         # Making param.grad points to the grad buffers before backward is based on the
@@ -1419,9 +1393,9 @@ def _clear_grad_buffer(self):
             if all_param_grad_none:
                 self._delay_grad_buffer.zero_()
 
-    def _pre_forward(self):
+    def _pre_forward(self, *inputs, **kwargs):
         if self._delay_all_reduce_all_params:
-            return
+            return inputs, kwargs
 
         if torch.is_grad_enabled() and self.require_backward_grad_sync:
             assert self.logger is not None
@@ -1456,6 +1430,33 @@ def _pre_forward(self):
             # Notify joined ranks whether they should sync in backwards pass or not.
             self._check_global_requires_backward_grad_sync(is_joined_rank=False)
 
+        if self.device_ids:
+            inputs, kwargs = _to_kwargs(
+                inputs,
+                kwargs,
+                torch.device(self.device_type, self.device_ids[0]),
+                self.use_side_stream_for_tensor_copies,
+            )
+            args, kwargs = inputs[0], kwargs[0]  # type: ignore[index]
+            # Cast inputs to reduced precision if needed.
+            if self.mixed_precision is not None:
+                args, kwargs = _cast_forward_inputs(
+                    self.mixed_precision.param_dtype,
+                    *args,
+                    **kwargs,
+                )
+            return args, kwargs
+        else:
+            # Cast inputs to reduced precision if needed.
+            # TODO (rohan-varma) test this codepath.
+            if self.mixed_precision is not None:
+                inputs, kwargs = _cast_forward_inputs(
+                    self.mixed_precision.param_dtype,
+                    *inputs,
+                    **kwargs,
+                )
+            return inputs, kwargs
+
     def _post_forward(self, output):
         if self._delay_all_reduce_all_params:
             self._clear_grad_buffer()
@@ -1528,7 +1529,7 @@ def _post_forward(self, output):
 
     def forward(self, *inputs, **kwargs):
         with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
-            self._pre_forward()
+            inputs, kwargs = self._pre_forward(*inputs, **kwargs)
             output = (
                 self.module.forward(*inputs, **kwargs)
                 if self._delay_all_reduce_all_params
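
For reference, a hedged sketch of the codepath the diff above changes: with device_ids set, DistributedDataParallel.forward now routes its inputs through _pre_forward, which moves them to the target device (and, when mixed_precision is configured, casts them to the reduced dtype) before the wrapped module runs. The single-process gloo setup below is for illustration only and assumes nothing beyond public PyTorch APIs.

# Illustration of the changed codepath; single-process setup for demonstration.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
model = nn.Linear(8, 8).to(device)

# With device_ids set, the _to_kwargs call that now lives in _pre_forward
# copies CPU inputs to cuda:0 before the wrapped module's forward runs.
ddp = DDP(model, device_ids=[0] if use_cuda else None)

x = torch.randn(4, 8)  # CPU tensor; moved (if CUDA) inside _pre_forward
out = ddp(x)
print(out.device)

dist.destroy_process_group()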
