[FSDP] Option to keep grads in lower precision
Differential Revision: [D39529117](https://our.internmc.facebook.com/intern/diff/D39529117/)

ghstack-source-id: 167439055
Pull Request resolved: #85062
rohan-varma committed Sep 15, 2022
1 parent f030988 commit 9789016
Showing 2 changed files with 68 additions and 1 deletion.
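
Before the diff itself, here is a minimal usage sketch of the new keep_casted_gradients field on MixedPrecision, modeled on the unit test added in this commit. It is not part of the diff; the launch details (torchrun-style environment, NCCL process group, LOCAL_RANK) are assumptions.

# Minimal usage sketch (illustration only, not part of the diff). Assumes a
# multi-GPU host and a launcher such as torchrun that sets LOCAL_RANK and the
# rendezvous environment variables.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10)).cuda()
    mp = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
        keep_casted_gradients=True,  # the option introduced by this commit
    )
    fsdp_model = FSDP(model, mixed_precision=mp)

    fsdp_model(torch.ones(1, 10, device="cuda")).sum().backward()

    # With keep_casted_gradients=True, gradients are left in the reduced
    # precision rather than being cast back to the full-precision param dtype.
    for p in fsdp_model.parameters():
        assert p.grad is None or p.grad.dtype == torch.bfloat16

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Without the flag, FSDP casts the reduced gradients back to the full-precision parameter dtype after the post-backward reduction, so the same assertion would see torch.float32 grads.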
35 changes: 35 additions & 0 deletions test/distributed/fsdp/test_fsdp_mixed_precision.py
@@ -286,6 +286,33 @@ def _reduce_scatter_base_validate_mp(

        return orig_reduce_scatter(*args, **kwargs)

    def _test_grads_reduced_precision(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = nn.Linear(10, 10)
                self.lin2 = nn.Linear(10, 10)

            def forward(self, x):
                return self.lin2(self.lin1(x))

        m = MyModel().cuda()
        mp = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
            keep_casted_gradients=True,
        )
        m.lin1 = FSDP(m.lin1, mixed_precision=mp)
        m = FSDP(m, mixed_precision=mp)
        for _ in range(6):
            inp = torch.ones(1, 10)
            m(inp).sum().backward()
            for param in m.parameters():
                self.assertEqual(torch.bfloat16, param.grad.dtype)

        dist.barrier()

    def _run_test_mixed_precision_e2e(
        self,
        mp_config,
@@ -576,6 +603,10 @@ def test_mixed_precision_resnet(self):
        loss = fsdp(inp).sum()
        loss.backward()

    @skip_if_lt_x_gpu(2)
    def test_grads_reduced_precision(self):
        self._test_grads_reduced_precision()

    @skip_if_lt_x_gpu(2)
    @parametrize("convert_sync_bn", [True, False])
    def test_mp_batchnorm(self, convert_sync_bn):
@@ -641,6 +672,10 @@ class TestFSDPMixedPrecisionUnsharded(TestFSDPMixedPrecision):
    def world_size(self):
        return 1

    @skip_if_lt_x_gpu(1)
    def test_grads_reduced_precision(self):
        return self._test_grads_reduced_precision()

    @skip_if_lt_x_gpu(1)
    def test_mixed_precision_no_reshard_after_forward(self):
        # Note that we don't exercise all possible different configs so as to
34 changes: 33 additions & 1 deletion torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -199,6 +199,7 @@ class MixedPrecision:
    # TODO: buffer + param are usually of the same type, if user specifies
    # param but not buffer, should we automatically make buffer be the same?
    buffer_dtype: Optional[torch.dtype] = None
    keep_casted_gradients: Optional[bool] = False


@dataclass
@@ -1280,6 +1281,12 @@ def _mixed_precision_enabled_for_reduce(self) -> bool:
            and self.mixed_precision.reduce_dtype is not None
        )

    def _mixed_precision_keep_low_precision_grads(self) -> bool:
        return (
            self.mixed_precision is not None
            and self.mixed_precision.keep_casted_gradients
        )

    def _low_precision_hook_enabled(self) -> bool:
        """
        Whether a low precision hook is registered or not.
@@ -2929,7 +2936,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
            if self.sharding_strategy == ShardingStrategy.NO_SHARD:
                self._communication_hook(self._communication_hook_state, param.grad)

                self._cast_grad_to_param_dtype(param.grad, param)
                # For NO_SHARD, when keeping grads in the reduced precision,
                # we can simply omit the cast; we can't do this for other
                # sharding strategies because the grad field is assigned in
                # _finalize_params. TODO (rvarm1): this divergence in logic
                # is not ideal.
                if not self._mixed_precision_keep_low_precision_grads():
                    self._cast_grad_to_param_dtype(param.grad, param)

            # Regardless of sharding or not, offload the grad to CPU if we are
            # offloading params. This is so param and grad reside on same device
@@ -3089,6 +3102,11 @@ def _finalize_params(fsdp_module: FullyShardedDataParallel) -> None:
            # lands. If it was not called, there is no new gradient to accumulate
            if p._post_backward_called:
                p.grad = p._saved_grad_shard

                if fsdp_module._mixed_precision_keep_low_precision_grads():
                    p.grad.data = p.grad.data.to(
                        fsdp_module.mixed_precision.param_dtype
                    )
        else:
            p_assert(
                not p._is_sharded or not p._post_backward_called,
@@ -3422,6 +3440,20 @@ def _prep_grads_for_backward(self) -> None:
                # warning in the class's docstring).
                if not offloaded:
                    p._saved_grad_shard = p.grad.data  # type: ignore[attr-defined]
                    # If we're using mixed precision and keeping grads in the
                    # casted (reduced) dtype, the gradient here might still be
                    # of the reduced dtype if we didn't clear / set the
                    # gradients to None after the previous iteration. In that
                    # case, make sure p._saved_grad_shard is cast to the full
                    # precision type so that we can accumulate in full
                    # precision in _post_backward_hook and assign back in full
                    # precision in _wait_for_post_backward.
                    if (
                        self._mixed_precision_keep_low_precision_grads() and
                        p._saved_grad_shard.dtype != p._local_shard.dtype
                    ):
                        p._saved_grad_shard = p._saved_grad_shard.to(p._local_shard.dtype)

                p.grad = None

    def _should_free_full_params(self):
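
For context on the _prep_grads_for_backward change above: the dtype check matters when gradients are accumulated across iterations without being cleared. The sketch below is an assumed usage pattern, not part of the diff; it reuses the fsdp_model built in the sketch near the top of this page and introduces a hypothetical optimizer.

# Assumed gradient-accumulation pattern (illustration only), reusing fsdp_model
# from the earlier sketch.
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-3)

for micro_step in range(2):
    # Two backward passes without clearing grads in between: the grad shard
    # saved from the first backward may still be in the reduced dtype, which
    # is the case the _prep_grads_for_backward change above handles.
    fsdp_model(torch.ones(1, 10, device="cuda")).sum().backward()

optimizer.step()
optimizer.zero_grad(set_to_none=True)  # clearing grads avoids the dtype-mismatch path on the next iteration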
