[FSDP] Fix use_orig_params=True, CPU offload, no_sync() #100180

Closed · wants to merge 3 commits
Changes from 1 commit
47 changes: 29 additions & 18 deletions test/distributed/fsdp/test_fsdp_grad_acc.py
@@ -117,12 +117,6 @@ def _test_grad_acc(
point to prefetch the next layer's full parameters during the
backward pass, if at all.
"""
# Gradient accumulation outside `no_sync()` is not currently compatible
Contributor Author:

Previously, we were skipping all of the CPU offloading cases here, since every tested config list included a config with use_no_sync == False 😢

That is why I thought use_orig_params=True worked with CPUOffload(True), when we were actually just skipping the test.
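For reference, a rough sketch (hypothetical values) of the kind of interleaved config list the test passes in; any entry with use_no_sync=False made the early return below fire whenever offload_params=True:

configs = _GradAccConfigs([
    _GradAccConfig(use_no_sync=True, num_iters=3),
    _GradAccConfig(use_no_sync=False, num_iters=3),  # any such entry triggered the skip
])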

# with CPU offloading
if cpu_offload.offload_params and any(
not config.use_no_sync for config in configs
):
return
# Initialize the FSDP model and optimizer
fsdp_kwargs = {
"cpu_offload": cpu_offload,
@@ -183,9 +177,7 @@ def permute_tensor(x: torch.Tensor):
batch_idx = 0
for config in configs:
sync_context = (
fsdp_model.no_sync()
if config.use_no_sync
else contextlib.suppress()
fsdp_model.no_sync() if config.use_no_sync else contextlib.suppress()
)
with sync_context:
for _ in range(config.num_iters):
@@ -228,15 +220,11 @@ def _get_subtest_config(self) -> Dict[str, List[Any]]:
BackwardPrefetch.BACKWARD_PRE,
BackwardPrefetch.BACKWARD_POST,
],
"cpu_offload": [
CPUOffload(offload_params=False),
CPUOffload(offload_params=True),
],
"sharding_strategy": [
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
ShardingStrategy.NO_SHARD,
]
],
}

@skip_if_lt_x_gpu(2)
@@ -266,20 +254,43 @@ def test_grad_acc(
use_orig_params: bool,
):
"""
Tests gradient accumulation.
Tests gradient accumulation without parameter CPU offloading.

This exercises gradient accumulation inside and outside the
``no_sync()`` context manager, in particular by interleaving the two.
It tests both interleaving starting with (and ending with, resp.)
inside versus outside ``no_sync()`` to ensure that initial conditions
(and final conditions, resp.) do not affect the correctness.
"""
subtest_config = self._get_subtest_config()
subtest_config["cpu_offload"] = [CPUOffload(offload_params=False)]
self.run_subtests(
subtest_config,
self._test_grad_acc,
batch_dim=1,
configs=configs.configs,
use_orig_params=use_orig_params,
)
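To make the docstring above concrete, here is a rough sketch of the interleaved pattern being exercised, with hypothetical fsdp_model, batch, and optim (the real test drives this through the configs list):

with fsdp_model.no_sync():
    for _ in range(3):  # accumulate locally; gradient reduce-scatter is skipped
        fsdp_model(batch).sum().backward()
for _ in range(3):  # accumulate while synchronizing gradients on every backward
    fsdp_model(batch).sum().backward()
optim.step()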

@skip_if_lt_x_gpu(2)
@parametrize("use_orig_params", [False, True])
Member:

Should we put this as a subtest?

Contributor Author:

I think we can avoid subtesting major config differences, i.e. those where, based on the implementation, we actually expect there may be a difference. At least that is what I have been doing so far.

def test_grad_acc_cpu_offload(
self,
use_orig_params: bool,
):
"""
Tests gradient accumulation with parameter CPU offloading.

NOTE: Gradient accumulation without using the ``no_sync()`` context
manager is not currently compatible with CPU offloading, so those tests
just return directly.
manager is not currently compatible with CPU offloading.
Member:

Curious, how do we error here?

Contributor Author:

We do not error. If I understand correctly, it is silently incorrect.
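For illustration, a minimal sketch of the unsupported combination, with hypothetical fsdp_model, batches, and optim and the model wrapped with CPUOffload(offload_params=True); nothing raises, but per the discussion above the accumulated gradients end up silently incorrect:

for batch in batches:  # gradient accumulation *outside* `no_sync()`
    fsdp_model(batch).sum().backward()  # gradients synchronize on every backward
optim.step()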

"""
# Only test `no_sync` since outside `no_sync()` is not supported with
# parameter CPU offloading
configs = _GradAccConfigs([_GradAccConfig(use_no_sync=True, num_iters=3)])
subtest_config = self._get_subtest_config()
subtest_config["cpu_offload"] = [CPUOffload(offload_params=True)]
self.run_subtests(
self._get_subtest_config(),
subtest_config,
self._test_grad_acc,
batch_dim=1,
configs=configs.configs,
28 changes: 20 additions & 8 deletions torch/distributed/fsdp/flat_param.py
@@ -1763,14 +1763,19 @@ def _use_unsharded_grad_views(self) -> None:
f"{self.flat_param._fqns[i]} is missing",
)
param = getattr(module, param_name)
if param.shape != view.shape or param.dtype != view.dtype:
# NOTE: This is a hack using `.data` to side step the
# check that parameter/gradient sizes and dtypes match. Here,
# `param` can have the sharded size, and `grad` can have the
# unsharded size. Orthogonally, `param` can have the full
# precision dtype from `reshard()`, and `grad` can have the
# parameter low precision dtype. Both of these mismatches
# happen when running in `no_sync()`.
if (
param.shape != view.shape
or param.dtype != view.dtype
or param.device != view.device
):
# NOTE: This is a hack using `.data` to side step the check
# that parameter/gradient sizes/dtypes/devices match. From
# calling `reshard()`, `param` has the sharded size, the full
# precision dtype, and is on CPU. Thus, one or more of the
Member:

Is it on CPU only if CPU offloading is enabled?

Contributor Author:

Yes.

# following cases can hold when in `no_sync()`:
# 1. `view` can have the unsharded size.
Member:

`view` is the grad here, right? Can we clarify that?

Contributor Author:

Let me do it in a follow-up to avoid re-triggering CI.

# 2. `view` can have the parameter low precision dtype.
# 3. `view` can be on GPU.
if param.grad is None:
param.grad = torch.empty_like(param)
param.grad.data = view
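As a standalone illustration of the hack above (toy CPU tensors, no FSDP involved): assigning through .grad enforces that sizes, dtypes, and devices match, while assigning through .grad.data swaps in the mismatched view without tripping that check:

import torch

param = torch.nn.Parameter(torch.zeros(4))  # e.g. sharded size, full precision
view = torch.zeros(8, dtype=torch.float16)  # e.g. unsharded size, low precision

# param.grad = view  # would raise: the setter checks size/dtype/device compatibility
param.grad = torch.empty_like(param)
param.grad.data = view  # bypasses the check, mirroring the code above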
@@ -1793,6 +1798,7 @@ def _use_unsharded_grad_views(self) -> None:
if (
param.shape != prim_param.grad.shape
or param.dtype != prim_param.grad.dtype
or param.device != prim_param.grad.device
):
# NOTE: This is the same hack to use `.data` to side step the
# size check.
@@ -2038,6 +2044,12 @@ def _writeback_orig_params(self) -> bool:
# For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
# memory and owns the gradient storage, so it will never
# require gradient writeback.
if not self.uses_sharded_strategy and self._offload_params:
# Explicitly continue to handle the case of `no_sync()`,
# where `param.grad` is a view into the GPU gradient
# referenced by `flat_param.grad`, while `flat_param_grad`
# is `flat_param._cpu_grad`, which is on CPU
continue
needs_grad_writeback = (
flat_param_grad is None
or not _same_storage_as_data_ptr(
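To make the explicit continue above concrete, a small sketch with toy tensors standing in for flat_param._cpu_grad, flat_param.grad, and param.grad (assumes a CUDA device is available):

import torch

cpu_grad = torch.zeros(8)  # stands in for `flat_param._cpu_grad`
gpu_grad = torch.zeros(8, device="cuda")  # stands in for `flat_param.grad` inside `no_sync()`
param_grad = gpu_grad[:4]  # stands in for `param.grad`, a view into the GPU gradient

# Tensors on different devices never share storage, so a raw data-pointer
# comparison would always report a mismatch and request gradient writeback;
# the explicit `continue` avoids that spurious writeback.
print(param_grad.data_ptr() == cpu_grad.data_ptr())  # False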