[FSDP] Fix "use-after-free" in reshard logic #94859
Conversation
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94859. Note: links to docs will display an error until the docs builds have been completed.

❌ 2 Failures as of commit 652457b (new failures).
ghstack-source-id: 3708f45d036b38eab8c61f89eb61381128251dac Pull Request resolved: #94859
Failures look unrelated:

- manywheel-py3_8-cuda11_7-test / test
- manywheel-py3_8-cuda11_7-with-pypi-cudnn-test / test
I confirmed that the internal run completed without error.
@pytorchbot merge -f "unrelated failures"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Stack from ghstack:
Overview
This PR switches the order of freeing the unsharded `FlatParameter` (`self._free_unsharded_flat_param()`) and switching to use the sharded `FlatParameter` (`self._use_sharded_flat_param()`). This is to prevent "use-after-free"-type bugs where, for `param.data = new_data`, `param` has its metadata intact but not its storage, causing an illegal memory access for any instrumentation that depends on its storage. (`param` is an original parameter, and `new_data` is either a view into the sharded `FlatParameter` or `torch.empty(0)`, depending on the sharding and rank.)

Details
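Before examining FSDP's actual code, here is a minimal hedged sketch of the hazard and of the fixed ordering. The names `unsharded`, `local_shard`, and `param` are illustrative stand-ins, not FSDP's variables:

```python
import torch

# Stand-ins: `unsharded` plays the unsharded FlatParameter, `local_shard`
# the sharded FlatParameter (kept resident), and `param` an original
# parameter whose .data currently views the unsharded tensor.
unsharded = torch.ones(8)
local_shard = torch.ones(4)
param = torch.nn.Parameter(unsharded[:4])

# Buggy order (before this PR): freeing the unsharded storage first would
# leave `param` with intact metadata but freed storage, so instrumentation
# reading param.data could hit an illegal memory access.

# Fixed order (this PR): first switch param.data to a view into the sharded
# FlatParameter (or torch.empty(0), depending on sharding and rank)...
param.data = local_shard[:4]
# ...and only then free the unsharded storage.
unsharded.untyped_storage().resize_(0)

# `param` is backed by valid storage throughout.
assert param.data.untyped_storage().size() > 0
```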
To see why simply switching the order of the two calls is safe, let us examine the calls themselves:
`pytorch/torch/distributed/fsdp/flat_param.py`, lines 1312 to 1339 at commit 652457b

`pytorch/torch/distributed/fsdp/flat_param.py`, lines 1298 to 1310 at commit 652457b
- `_free_unsharded_flat_param()` does not make any assumption that `self.flat_param`'s data is the sharded `FlatParameter` (i.e. `_local_shard`).
- The sharded `FlatParameter` (i.e. `_local_shard`) is always present in memory, which means that FSDP can use sharded views at any time, including before freeing the unsharded data (see the sketch after this list).
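To make the second point concrete, here is a minimal sketch (with assumed tensor names, not FSDP's own) showing that views into the always-resident sharded data stay valid regardless of when the unsharded buffer is freed:

```python
import torch

# `local_shard` stands in for the sharded FlatParameter, which FSDP keeps
# allocated for the lifetime of the wrapped module; `unsharded` stands in
# for the transient all-gathered buffer.
local_shard = torch.arange(4.0)
unsharded = torch.zeros(8)

view_before = local_shard[:2]            # sharded view taken before the free
unsharded.untyped_storage().resize_(0)   # free the unsharded buffer
view_after = local_shard[2:]             # sharded view taken after the free

# Both views are backed by live storage either way.
assert torch.equal(view_before, torch.tensor([0.0, 1.0]))
assert torch.equal(view_after, torch.tensor([2.0, 3.0]))
```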