[FSDP] Fix use_orig_params=True, CPU offload, no_sync() #100180
@@ -117,12 +117,6 @@ def _test_grad_acc(
         point to prefetch the next layer's full parameters during the
         backward pass, if at all.
         """
-        # Gradient accumulation outside `no_sync()` is not currently compatible
-        # with CPU offloading
-        if cpu_offload.offload_params and any(
-            not config.use_no_sync for config in configs
-        ):
-            return
         # Initialize the FSDP model and optimizer
         fsdp_kwargs = {
             "cpu_offload": cpu_offload,
@@ -183,9 +177,7 @@ def permute_tensor(x: torch.Tensor):
         batch_idx = 0
         for config in configs:
             sync_context = (
-                fsdp_model.no_sync()
-                if config.use_no_sync
-                else contextlib.suppress()
+                fsdp_model.no_sync() if config.use_no_sync else contextlib.suppress()
             )
             with sync_context:
                 for _ in range(config.num_iters):
@@ -228,15 +220,11 @@ def _get_subtest_config(self) -> Dict[str, List[Any]]:
                 BackwardPrefetch.BACKWARD_PRE,
                 BackwardPrefetch.BACKWARD_POST,
             ],
-            "cpu_offload": [
-                CPUOffload(offload_params=False),
-                CPUOffload(offload_params=True),
-            ],
             "sharding_strategy": [
                 ShardingStrategy.FULL_SHARD,
                 ShardingStrategy.SHARD_GRAD_OP,
                 ShardingStrategy.NO_SHARD,
-            ]
+            ],
         }

     @skip_if_lt_x_gpu(2)
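For context, the dictionary returned by `_get_subtest_config` is consumed by the test harness's `run_subtests`, which runs the test body once per combination drawn from the config-value lists. Conceptually it behaves like the following sketch (a simplification for illustration only, not the actual helper):

import itertools

def run_subtests_sketch(subtest_config, test_fn, **common_kwargs):
    # Iterate over the Cartesian product of all config-value lists and call
    # the test once per combination; the real harness additionally wraps each
    # run for subtest-style reporting.
    keys = sorted(subtest_config.keys())
    for values in itertools.product(*(subtest_config[key] for key in keys)):
        test_fn(**dict(zip(keys, values)), **common_kwargs)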
@@ -266,20 +254,43 @@ def test_grad_acc(
         use_orig_params: bool,
     ):
         """
-        Tests gradient accumulation.
+        Tests gradient accumulation without parameter CPU offloading.

         This exercises gradient accumulation inside and outside the
         ``no_sync()`` context manager, in particular by interleaving the two.
         It tests both interleaving starting with (and ending with, resp.)
         inside versus outside ``no_sync()`` to ensure that initial conditions
         (and final conditions, resp.) do not affect the correctness.
+        """
+        subtest_config = self._get_subtest_config()
+        subtest_config["cpu_offload"] = [CPUOffload(offload_params=False)]
+        self.run_subtests(
+            subtest_config,
+            self._test_grad_acc,
+            batch_dim=1,
+            configs=configs.configs,
+            use_orig_params=use_orig_params,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    @parametrize("use_orig_params", [False, True])
Reviewer: should we put this as a subtest?
Author: I think we can avoid subtesting major config differences, where, based on the implementation, we actually expect there may be a difference. At least, that is what I have been doing so far.
+    def test_grad_acc_cpu_offload(
+        self,
+        use_orig_params: bool,
+    ):
+        """
+        Tests gradient accumulation with parameter CPU offloading.

         NOTE: Gradient accumulation without using the ``no_sync()`` context
-        manager is not currently compatible with CPU offloading, so those tests
-        just return directly.
+        manager is not currently compatible with CPU offloading.
Reviewer: curious, how do we error here?
Author: We do not error. It is silently incorrect, if I understand correctly.
""" | ||
# Only test `no_sync` since outside `no_sync()` is not supported with | ||
# parameter CPU offloading | ||
configs = _GradAccConfigs([_GradAccConfig(use_no_sync=True, num_iters=3)]) | ||
subtest_config = self._get_subtest_config() | ||
subtest_config["cpu_offload"] = [CPUOffload(offload_params=True)] | ||
self.run_subtests( | ||
self._get_subtest_config(), | ||
subtest_config, | ||
self._test_grad_acc, | ||
batch_dim=1, | ||
configs=configs.configs, | ||
|
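For orientation, the pattern that `_test_grad_acc` exercises per config looks roughly like the following sketch (not code from this PR; `fsdp_model`, `loss_fn`, and `batches` are hypothetical, and an initialized distributed process group is assumed):

import contextlib

def run_iters(fsdp_model, loss_fn, batches, use_no_sync: bool, num_iters: int):
    # Inside `no_sync()`, each backward accumulates gradients locally without
    # FSDP's gradient communication; outside, every backward synchronizes.
    sync_context = (
        fsdp_model.no_sync() if use_no_sync else contextlib.suppress()
    )
    with sync_context:
        for _ in range(num_iters):
            inp, target = next(batches)  # `batches` is a hypothetical iterator
            loss = loss_fn(fsdp_model(inp), target)
            loss.backward()

The test interleaves such runs with `use_no_sync=True` and `use_no_sync=False` so that neither the initial nor the final accumulation mode affects correctness; with parameter CPU offloading, only the `use_no_sync=True` case is exercised.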
@@ -1763,14 +1763,19 @@ def _use_unsharded_grad_views(self) -> None:
                 f"{self.flat_param._fqns[i]} is missing",
             )
             param = getattr(module, param_name)
-            if param.shape != view.shape or param.dtype != view.dtype:
-                # NOTE: This is a hack using `.data` to side step the
-                # check that parameter/gradient sizes and dtypes match. Here,
-                # `param` can have the sharded size, and `grad` can have the
-                # unsharded size. Orthogonally, `param` can have the full
-                # precision dtype from `reshard()`, and `grad` can have the
-                # parameter low precision dtype. Both of these mismatches
-                # happen when running in `no_sync()`.
+            if (
+                param.shape != view.shape
+                or param.dtype != view.dtype
+                or param.device != view.device
+            ):
+                # NOTE: This is a hack using `.data` to side step the check
+                # that parameter/gradient sizes/dtypes/devices match. From
+                # calling `reshard()`, `param` has the sharded size, the full
+                # precision dtype, and is on CPU. Thus, one or more of the
Reviewer: is on CPU only if CPU offloading?
Author: Yes.
+                # following cases can hold when in `no_sync()`:
+                # 1. `view` can have the unsharded size.
Author: Let me do it in a follow-up to avoid re-triggering CI.
+                # 2. `view` can have the parameter low precision dtype.
+                # 3. `view` can be on GPU.
                 if param.grad is None:
                     param.grad = torch.empty_like(param)
                 param.grad.data = view
@@ -1793,6 +1798,7 @@ def _use_unsharded_grad_views(self) -> None:
                 if (
                     param.shape != prim_param.grad.shape
                     or param.dtype != prim_param.grad.dtype
+                    or param.device != prim_param.grad.device
                 ):
                     # NOTE: This is the same hack to use `.data` to side step the
                     # size check.
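As background for the NOTE above: assigning a tensor directly to `.grad` is validated against the parameter's size, dtype, and device, whereas swapping `.data` on an existing gradient tensor bypasses that validation. A standalone illustration (not code from this PR):

import torch

param = torch.nn.Parameter(torch.zeros(4))
view = torch.ones(8)  # stands in for an unsharded-size gradient view

# `param.grad = view` would raise a RuntimeError because the sizes differ.
# Swapping the underlying data of an existing grad tensor sidesteps the check:
param.grad = torch.empty_like(param)
param.grad.data = view
print(param.grad.shape)  # torch.Size([8])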
@@ -2038,6 +2044,12 @@ def _writeback_orig_params(self) -> bool:
                 # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
                 # memory and owns the gradient storage, so it will never
                 # require gradient writeback.
+                if not self.uses_sharded_strategy and self._offload_params:
+                    # Explicitly continue to handle the case of `no_sync()`,
+                    # where `param.grad` is a view into the GPU gradient
+                    # referenced by `flat_param.grad`, while `flat_param_grad`
+                    # is `flat_param._cpu_grad`, which is on CPU
+                    continue
                 needs_grad_writeback = (
                     flat_param_grad is None
                     or not _same_storage_as_data_ptr(
Author: Previously, we were skipping all CPU offloading tests since every tested config included `use_no_sync == False` 😢. That is why I thought `use_orig_params=True` worked with `CPUOffload(True)`, but we were actually skipping the test.
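For reference, the configuration this PR fixes can be exercised roughly as follows (a minimal sketch, not taken from the PR; `model` and `microbatches` are hypothetical, and an initialized process group plus a CUDA device are assumed; per the NOTE above, gradients are accumulated only inside `no_sync()` when parameter CPU offloading is enabled):

import torch
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# `model` is a hypothetical nn.Module already moved to the current CUDA device.
fsdp_model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),
    use_orig_params=True,
)
optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-3)

# Accumulate over all but the last microbatch without gradient communication,
# then run the final microbatch outside `no_sync()` to reduce gradients and step.
with fsdp_model.no_sync():
    for inp in microbatches[:-1]:
        fsdp_model(inp).sum().backward()
fsdp_model(microbatches[-1]).sum().backward()
optim.step()
optim.zero_grad()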