[FSDP] Fix use_orig_params=True, CPU offload, no_sync() (#100180)
This should fix #98494. We follow an approach similar to past PRs that handled mismatched dtypes or sizes arising from running in `no_sync()`.
Pull Request resolved: #100180
Approved by: https://github.com/rohan-varma
awgu authored and pytorchmergebot committed May 1, 2023
1 parent e779a30 commit 83b803c
Showing 2 changed files with 47 additions and 22 deletions.
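As context for the approach described in the commit message, here is a minimal, hypothetical sketch of the ".data" sidestep pattern (the tensor names and values are illustrative, not the exact PR code): assigning through `param.grad.data` swaps the underlying tensor and thereby avoids the check that a parameter and its gradient agree in shape, dtype, and device.

import torch

# Hypothetical standalone example: `param` plays the role of a sharded,
# full-precision, possibly CPU-offloaded original parameter, while
# `grad_view` plays the role of an unsharded, low-precision gradient view.
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float32))
grad_view = torch.ones(8, dtype=torch.float16)

if (
    param.shape != grad_view.shape
    or param.dtype != grad_view.dtype
    or param.device != grad_view.device
):
    if param.grad is None:
        # Allocate a placeholder gradient so there is a `.data` to swap.
        param.grad = torch.empty_like(param)
    # Swapping `.data` sidesteps the shape/dtype/device check that a plain
    # `param.grad = grad_view` assignment would perform.
    param.grad.data = grad_view
else:
    param.grad = grad_view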
41 changes: 27 additions & 14 deletions test/distributed/fsdp/test_fsdp_grad_acc.py
@@ -117,12 +117,6 @@ def _test_grad_acc(
point to prefetch the next layer's full parameters during the
backward pass, if at all.
"""
# Gradient accumulation outside `no_sync()` is not currently compatible
# with CPU offloading
if cpu_offload.offload_params and any(
not config.use_no_sync for config in configs
):
return
# Initialize the FSDP model and optimizer
fsdp_kwargs = {
"cpu_offload": cpu_offload,
@@ -226,10 +220,6 @@ def _get_subtest_config(self) -> Dict[str, List[Any]]:
BackwardPrefetch.BACKWARD_PRE,
BackwardPrefetch.BACKWARD_POST,
],
"cpu_offload": [
CPUOffload(offload_params=False),
CPUOffload(offload_params=True),
],
"sharding_strategy": [
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
@@ -264,20 +254,43 @@ def test_grad_acc(
use_orig_params: bool,
):
"""
Tests gradient accumulation.
Tests gradient accumulation without parameter CPU offloading.
This exercises gradient accumulation inside and outside the
``no_sync()`` context manager, in particular by interleaving the two.
It tests both interleaving starting with (and ending with, resp.)
inside versus outside ``no_sync()`` to ensure that initial conditions
(and final conditions, resp.) do not affect the correctness.
"""
subtest_config = self._get_subtest_config()
subtest_config["cpu_offload"] = [CPUOffload(offload_params=False)]
self.run_subtests(
subtest_config,
self._test_grad_acc,
batch_dim=1,
configs=configs.configs,
use_orig_params=use_orig_params,
)

@skip_if_lt_x_gpu(2)
@parametrize("use_orig_params", [False, True])
def test_grad_acc_cpu_offload(
self,
use_orig_params: bool,
):
"""
Tests gradient accumulation with parameter CPU offloading.
NOTE: Gradient accumulation without using the ``no_sync()`` context
manager is not currently compatible with CPU offloading, so those tests
just return directly.
manager is not currently compatible with CPU offloading.
"""
# Only test `no_sync` since outside `no_sync()` is not supported with
# parameter CPU offloading
configs = _GradAccConfigs([_GradAccConfig(use_no_sync=True, num_iters=3)])
subtest_config = self._get_subtest_config()
subtest_config["cpu_offload"] = [CPUOffload(offload_params=True)]
self.run_subtests(
self._get_subtest_config(),
subtest_config,
self._test_grad_acc,
batch_dim=1,
configs=configs.configs,
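The tests above interleave gradient accumulation inside and outside `no_sync()`. As a rough sketch of what that usage looks like from a training loop (the model/optimizer construction is omitted and the function name is illustrative, not part of the PR):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def accumulate_then_step(fsdp_model: FSDP, optim: torch.optim.Optimizer, batches):
    # All but the last microbatch run inside `no_sync()`, so gradients
    # accumulate locally in the unsharded gradient without communication.
    with fsdp_model.no_sync():
        for batch in batches[:-1]:
            fsdp_model(batch).sum().backward()
    # The final microbatch runs outside `no_sync()`, so this backward
    # reduces (reduce-scatters) the accumulated gradients across ranks.
    fsdp_model(batches[-1]).sum().backward()
    optim.step()
    optim.zero_grad()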
28 changes: 20 additions & 8 deletions torch/distributed/fsdp/flat_param.py
@@ -1772,14 +1772,19 @@ def _use_unsharded_grad_views(self) -> None:
f"{self.flat_param._fqns[i]} is missing",
)
param = getattr(module, param_name)
if param.shape != view.shape or param.dtype != view.dtype:
# NOTE: This is a hack using `.data` to side step the
# check that parameter/gradient sizes and dtypes match. Here,
# `param` can have the sharded size, and `grad` can have the
# unsharded size. Orthogonally, `param` can have the full
# precision dtype from `reshard()`, and `grad` can have the
# parameter low precision dtype. Both of these mismatches
# happen when running in `no_sync()`.
if (
param.shape != view.shape
or param.dtype != view.dtype
or param.device != view.device
):
# NOTE: This is a hack using `.data` to side step the check
# that parameter/gradient sizes/dtypes/devices match. From
# calling `reshard()`, `param` has the sharded size, the full
# precision dtype, and is on CPU. Thus, one or more of the
# following cases can hold when in `no_sync()`:
# 1. `view` can have the unsharded size.
# 2. `view` can have the parameter low precision dtype.
# 3. `view` can be on GPU.
if param.grad is None:
param.grad = torch.empty_like(param)
param.grad.data = view
@@ -1802,6 +1807,7 @@ def _use_unsharded_grad_views(self) -> None:
if (
param.shape != prim_param.grad.shape
or param.dtype != prim_param.grad.dtype
or param.device != prim_param.grad.device
):
# NOTE: This is the same hack to use `.data` to side step the
# size check.
@@ -2047,6 +2053,12 @@ def _writeback_orig_params(self) -> bool:
# For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
# memory and owns the gradient storage, so it will never
# require gradient writeback.
if not self.uses_sharded_strategy and self._offload_params:
# Explicitly continue to handle the case of `no_sync()`,
# where `param.grad` is a view into the GPU gradient
# referenced by `flat_param.grad`, while `flat_param_grad`
# is `flat_param._cpu_grad`, which is on CPU
continue
needs_grad_writeback = (
flat_param_grad is None
or not _same_storage_as_data_ptr(
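For intuition on the `_writeback_orig_params` change, here is a hypothetical sketch of why a raw storage comparison misfires for `NO_SHARD` + CPU offloading inside `no_sync()`: the offloaded flat gradient lives on CPU while `param.grad` is still a view into the GPU gradient, so their storages can never match, even though `_cpu_grad` already owns the gradient and no writeback is needed. The helper below is an illustrative stand-in, not the actual `_same_storage_as_data_ptr` implementation.

import torch

def same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Illustrative stand-in for the storage/data-pointer comparison that
    # decides whether the original parameter's gradient must be written back.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

device = "cuda" if torch.cuda.is_available() else "cpu"
flat_param_cpu_grad = torch.zeros(8)        # stands in for `flat_param._cpu_grad`
gpu_grad = torch.ones(8, device=device)     # stands in for the GPU `flat_param.grad`
param_grad_view = gpu_grad[:4]              # stands in for `param.grad` under `no_sync()`

# The CPU and GPU tensors never share storage, so a naive check would flag a
# writeback even though `_cpu_grad` already owns the gradient storage; this is
# why the fix explicitly `continue`s in this case.
print(same_storage(flat_param_cpu_grad, param_grad_view))  # False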
