
[FSDP] Fix "use-after-free" in reshard logic #94859

Closed
wants to merge 1 commit

Conversation

Contributor

@awgu awgu commented Feb 14, 2023

Stack from ghstack:

Overview
This PR switches the order of freeing the unsharded FlatParameter (self._free_unsharded_flat_param()) and switching to use the sharded FlatParameter (self._use_sharded_flat_param()). This prevents "use-after-free"-type bugs where, after param.data = new_data, param has its metadata intact but its storage already freed, causing an illegal memory access for any instrumentation that depends on that storage. (Here, param is an original parameter and new_data is either a view into the sharded FlatParameter or torch.empty(0), depending on the sharding and rank.)
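
As a condensed before/after sketch of the reorder (the method names `_reshard_before_this_pr` / `_reshard_after_this_pr` and the surrounding structure are illustrative; only the two calls themselves come from the actual code):

```python
# Illustrative sketch of the reshard step, not the exact FSDP source. Only the
# two method calls are from the PR description; any condition on whether to
# free the unsharded storage is elided.

def _reshard_before_this_pr(self) -> None:
    self._free_unsharded_flat_param()  # free padded unsharded storage first
    self._use_sharded_flat_param()     # then repoint params at sharded data
    # Window between the two calls: original params have metadata but no storage.

def _reshard_after_this_pr(self) -> None:
    self._use_sharded_flat_param()     # switch to sharded data/views first
    self._free_unsharded_flat_param()  # then free the padded unsharded storage
```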

Details
To see why simply switching the order of the two calls is safe, let us examine the calls themselves:

```python
def _use_sharded_flat_param(self) -> None:
    """Switches to using the sharded flattened parameter."""
    flat_param = self.flat_param
    if self._offload_params:
        device = flat_param._local_shard.device  # type: ignore[attr-defined]
        p_assert(
            device == torch.device("cpu"),
            f"Expects the local shard to be on CPU but got {device}",
        )
    flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
    if self._use_orig_params:
        self._use_sharded_views()
        # For the post-forward reshard, we may try to use sharded gradient
        # views (or unsharded gradient views if a gradient was accumulated
        # in `no_sync()`), but for the post-backward reshard, we delay the
        # call to after the reduce-scatter.
        if self._training_state == HandleTrainingState.FORWARD:
            # TODO: Change `_unpadded_unsharded_size` if we change the
            # gradient to be computed directly with padding.
            accumulated_grad_in_no_sync = (
                flat_param.grad is not None
                and self.uses_sharded_strategy
                and flat_param.grad.shape == flat_param._unpadded_unsharded_size
            )
            if accumulated_grad_in_no_sync:
                self._use_unsharded_grad_views()
            else:
                self._use_sharded_grad_views()
```

```python
def _free_unsharded_flat_param(self):
    """
    Frees the padded unsharded flattened parameter. The tensor to free
    depends on the calling context since the unshard may have forced full
    precision, in which case a different tensor is used.
    """
    self._check_sharded_strategy()
    unsharded_flat_param = self._get_padded_unsharded_flat_param()
    self._check_storage_allocated(unsharded_flat_param)
    self._check_on_compute_device(unsharded_flat_param)
    # Do not free the memory until all ops in the current stream finish
    _no_dispatch_record_stream(unsharded_flat_param, torch.cuda.current_stream())
    _free_storage(unsharded_flat_param)
```

  • _free_unsharded_flat_param() does not assume that self.flat_param's data is the sharded FlatParameter (i.e. _local_shard).
  • The sharded FlatParameter (i.e. _local_shard) is always present in memory, so FSDP can switch to sharded views at any time, including before freeing the unsharded data (see the sketch below).
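
For intuition, here is a minimal standalone sketch of the "metadata intact, storage freed" state that the old ordering could expose. It uses plain tensors, not FSDP internals; names like `unsharded` and `param` are illustrative.

```python
import torch

# Hypothetical stand-ins (not FSDP code): an "unsharded" buffer and an
# original parameter created as a view into it.
unsharded = torch.empty(8)
param = torch.nn.Parameter(torch.empty(0))
param.data = unsharded[:4]  # param now views the unsharded data

# Analogue of freeing the unsharded storage: release the shared backing
# storage while leaving the tensor objects (and shape/dtype metadata) intact.
unsharded.untyped_storage().resize_(0)

print(param.shape)                     # still torch.Size([4]) -- metadata intact
print(param.untyped_storage().size())  # 0 -- backing storage is gone
# Any hook or instrumentation that reads `param`'s values at this point hits
# the "use-after-free" condition described above; installing the sharded views
# before freeing removes this window.
```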


pytorch-bot bot commented Feb 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94859

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Failures

As of commit 652457b:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) label Feb 14, 2023
awgu added a commit that referenced this pull request Feb 14, 2023
ghstack-source-id: 3708f45d036b38eab8c61f89eb61381128251dac
Pull Request resolved: #94859
@awgu awgu added the topic: bug fixes and ciflow/trunk labels Feb 14, 2023
Contributor Author

awgu commented Feb 15, 2023

Failures look unrelated:

manywheel-py3_8-cuda11_7-test / test

Test that linalg works
+ python -c 'import torch;x=torch.rand(3,3,device='\''cuda'\'');print(torch.linalg.svd(torch.mm(x.t(), x)))'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
RuntimeError: Error in dlopen: /opt/python/cp38-cp38/lib/python3.8/site-packages/torch/lib/libtorch_cuda_linalg.so: undefined symbol: sgebak_

manywheel-py3_8-cuda11_7-with-pypi-cudnn-test / test

Test that linalg works
+ python -c 'import torch;x=torch.rand(3,3,device='\''cuda'\'');print(torch.linalg.svd(torch.mm(x.t(), x)))'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
RuntimeError: Error in dlopen: /opt/python/cp38-cp38/lib/python3.8/site-packages/torch/lib/libtorch_cuda_linalg.so: undefined symbol: sgebak_

@awgu awgu marked this pull request as ready for review February 15, 2023 12:13
Contributor Author

awgu commented Feb 15, 2023

I confirmed that the internal run completed without error.

Contributor Author

awgu commented Feb 15, 2023

@pytorchbot merge -f "unrelated failures"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/awgu/328/head branch June 8, 2023 15:35