[HSDP] Fix Node 1 unable to receive parameters from Node 0 #108331
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108331
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit ac714a3 with merge base 121cfb6.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for pointing this issue out! I think that the current broadcast semantics, over only the sharding process group, are an issue. However, I am not sure that using the global process group is the right thing to do in general. I think the right semantics are to "union" the sharding process group and the replication process group and broadcast over this unioned process group. When using HSDP alone, this union is the global process group. However, if HSDP does not have full ownership over the cluster (e.g. if composing with some other parallelism), then the union may not be the global process group. cc: @wanchaol @wz337
@awgu You are right, there is indeed an issue with using the global process group. I have modified the logic here: first synchronize the parameters from gpu0 to gpu1,...,gpu7 on node 0, and then, according to `state._inter_node_pg`, synchronize the parameters from gpu0 to gpu8, gpu1 to gpu9, gpu2 to gpu10, ..., so the params on gpu0 are broadcast to all ranks. Is this all right? Thanks for your reply.
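For readers following along, here is a minimal sketch of the two-step broadcast path described above, assuming a hypothetical 2-node x 8-GPU layout; the helper function and rank arithmetic are illustrative, not code from this PR.

```python
# Illustrative two-step sync for 2 nodes x 8 GPUs (not code from this PR).
import torch.distributed as dist

def sync_from_global_rank0(tensor, intra_node_pg, inter_node_pg, gpus_per_node=8):
    rank = dist.get_rank()
    node_leader = (rank // gpus_per_node) * gpus_per_node  # rank 0 on node 0, rank 8 on node 1
    # Step 1: intra-node broadcast from each node's leader
    # (on node 0: gpu0 -> gpu1..7; node 1's copy is overwritten in step 2).
    dist.broadcast(tensor, src=node_leader, group=intra_node_pg)
    # Step 2: inter-node broadcast along this rank's replication group
    # (gpu0 -> gpu8, gpu1 -> gpu9, ..., gpu7 -> gpu15).
    dist.broadcast(tensor, src=rank % gpus_per_node, group=inter_node_pg)
```

After both steps, every rank holds global rank 0's values.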
@lxg2015 This approach seems reasonable to me. I am wondering if you can add a unit test in https://github.com/pytorch/pytorch/blob/main/test/distributed/fsdp/test_fsdp_hybrid_shard.py. The unit test probably needs 4 GPUs (shard across 2 and replicate across 2).
@awgu I added a unit test commit and tested successfully on 4 GPUs and 8 GPUs.
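For reference, a minimal sketch of what such a test can check (structure and names are illustrative; the actual test is in test_fsdp_hybrid_shard.py):

```python
# Illustrative sketch, not the merged test: zero the params on rank 0 only,
# then verify sync_module_states propagates rank 0's values to every rank.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

def _check_hsdp_broadcast(model):
    if dist.get_rank() == 0:
        with torch.no_grad():
            for p in model.parameters():
                p.zero_()
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        sync_module_states=True,
        device_id=torch.cuda.current_device(),
    )
    # Every rank, including those in the replicated dimension, should see zeros.
    with FSDP.summon_full_params(model):
        for p in model.parameters():
            assert (p == 0).all()
```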
Thank you @lxg2015!
```python
model = fsdp_ctor(model)

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    assert (model.lin1.weight == 0).all()
```
nit: I think the error messaging might be better if we use `self.assertTrue((model.lin1.weight == 0).all())`? Or at least, I think using `self.assert<...>` is the common practice.
Yes, I have changed `assert` to `self.assertTrue`.
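For concreteness, the change amounts to the following (assertion adapted from the diff above):

```python
# Before: a bare assert produces an unhelpful failure message under unittest.
assert (model.lin1.weight == 0).all()
# After: report the failure through the test framework
# (inside a unittest.TestCase method).
self.assertTrue((model.lin1.weight == 0).all())
```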
```diff
@@ -516,6 +516,10 @@ def _init_param_handle_from_module(
         _sync_module_params_and_buffers(
             fully_sharded_module, managed_params, state.process_group
         )
+        if hasattr(state, '_inter_node_pg'):
```
nit: I wonder if this might be brittle as a check for whether we are using HSDP. Perhaps we can do `getattr(state, "_inter_node_pg", None) is not None`?
Yes, I have changed this.
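Putting the two review comments together, the guarded sync presumably ends up shaped like this (a sketch based on the diff above; the merged code may differ):

```python
# Guard on the HSDP-only attribute, then run the second (inter-node) sync.
inter_node_pg = getattr(state, "_inter_node_pg", None)
if inter_node_pg is not None:
    _sync_module_params_and_buffers(
        fully_sharded_module, managed_params, inter_node_pg
    )
```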
Hello @awgu, can this PR be merged? To reduce memory consumption, many other packages, such as accelerate, load the checkpoint only on gpu 0; if we don't fix this, it causes abnormal loss.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 2 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot rebase -s |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
When using hybrid_shard mode FSDP, state.process_group covers only gpu0,...,gpu7 on node 0, so the GPUs on node 1 cannot receive parameters; setting process_group to the default (global) group fixes this issue.
First broadcast the params from gpu0 to gpu1,...,gpu7 on node 0, then broadcast the params from gpu0 to gpu8, gpu1 to gpu9, gpu2 to gpu10, ..., so the params on gpu0 are broadcast to all ranks. This also works when there are more nodes.
Successfully rebased. Force-pushed from 352fad9 to ac714a3.
Yes, I think you can call a collective over two dimensions of the mesh. I believe Wanchao removed DeviceMesh's collectives, since they were just a thin layer over functional collectives, so you should be able to use functional collectives directly for this. If I understand your use case correctly, this may be what you are looking for. Code pointer: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L161 And yes, DeviceMesh would initialize new process groups if needed. The current logic is that for a mesh that matches the world size, it re-uses the group if it has already been initialized; for sub-pgs, it goes through creation. cc. @wanchaol
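For illustration, a sketch of calling a functional collective over a single mesh dimension; the 2x2 mesh and tensor here are hypothetical, and the funcol API is assumed as described above:

```python
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DeviceMesh

# Hypothetical 2x2 mesh: dim 0 = replication (inter-node), dim 1 = sharding.
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])

t = torch.ones(4, device="cuda")
# Passing (mesh, dim) as the group runs the collective over just that dimension.
t = funcol.all_reduce(t, reduceOp="sum", group=(mesh, 0))
```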
Hi @awgu, all checks appear to be fine; can this PR be merged? :grin: Or is DeviceMesh a better way to fix this?
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Hi @lxg2015, @awgu, thanks for the PR and the review. One question here about the `managed_params` passed to `_sync_module_params_and_buffers`: does each FSDP instance sync only the parameters it manages?
@haichaoyu Each FSDP instance only syncs the parameters and buffers that it directly manages (its `managed_params`). For example, suppose you had a transformer model with transformer blocks, where you apply FSDP to each transformer block and then finally to the root transformer. When you apply FSDP to each transformer block, that block's parameters are synced by the block's own FSDP instance, and the root instance then syncs only the remaining parameters not managed by any nested instance (e.g. the embeddings). Let me know if this makes sense!
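A sketch of the wrapping described above (module names are illustrative; device handling omitted):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_transformer(model: nn.Module) -> FSDP:
    # Each block gets its own FSDP instance, which syncs the block's params.
    for i, block in enumerate(model.blocks):  # assumes an nn.ModuleList
        model.blocks[i] = FSDP(block, sync_module_states=True)
    # The root instance manages and syncs only the leftovers (e.g. embeddings).
    return FSDP(model, sync_module_states=True)
```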
Got it. Thanks for the detailed explanation!
@awgu Another relevant question: what if a parent module re-initializes a buffer of a child module that belongs to a different FSDP wrapping?
If a parent module re-initializes the buffer of a child module, where the child module is part of a different FSDP wrapping, then yes, this could cause issues. The general guidance for FSDP is that each module should only initialize its directly owned parameters/buffers to avoid cases like this.
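As a concrete (hypothetical) example of the anti-pattern:

```python
import torch
import torch.nn as nn

class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("stats", torch.zeros(8))

    def reset_parameters(self):
        self.stats.zero_()  # OK: the child touches only its own buffer.

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

    def reset_parameters(self):
        # Anti-pattern: the parent re-initializes the child's buffer. If the
        # child sits in a different FSDP wrapping, this write can land after
        # that instance already synced, leaving ranks with divergent values.
        self.child.stats.normal_()
```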