
[HSDP] Fix Node 1 unable to receive parameters from Node 0 #108331

Closed
lxg2015 wants to merge 7 commits

Conversation

lxg2015
Contributor

@lxg2015 lxg2015 commented Aug 31, 2023

When using FSDP in hybrid_shard mode, state.process_group only contains GPUs 0-7 on node 0, so the GPUs on node 1 cannot receive the parameters. Setting process_group to the default (global) group can fix this issue.

Fixes #ISSUE_NUMBER
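
For concreteness, a minimal illustration of the group layout assumed here (2 nodes x 8 GPUs; illustration only, not the FSDP internals, and it assumes torch.distributed is already initialized with 16 ranks):

import torch.distributed as dist

world_size = 16      # assumed: 2 nodes x 8 GPUs
gpus_per_node = 8

# Intra-node sharding groups: [0..7] on node 0 and [8..15] on node 1.
# Broadcasting module states only over these groups never gets rank 0's
# parameters onto node 1, because rank 0 is not a member of node 1's group.
# Note: every rank must participate in every new_group call.
shard_groups = [
    dist.new_group(list(range(n * gpus_per_node, (n + 1) * gpus_per_node)))
    for n in range(world_size // gpus_per_node)
]

# Inter-node replication groups: [0, 8], [1, 9], ..., [7, 15].
replicate_groups = [
    dist.new_group(list(range(r, world_size, gpus_per_node)))
    for r in range(gpus_per_node)
]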

@pytorch-bot

pytorch-bot bot commented Aug 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108331

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ac714a3 with merge base 121cfb6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Aug 31, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Aug 31, 2023
@awgu
Contributor

awgu commented Aug 31, 2023

Thanks for pointing this issue out!

I think that the current broadcast semantics over only the sharding process group is an issue. However, I am not sure if using the global process group is the right thing to do in general.

I think the right semantics are to "union" the sharding process group and replication process group and broadcast over this unioned process group. When only using HSDP alone, this union is the global process group. However, if HSDP does not have full ownership over the cluster (e.g. if composing with some other parallelism), then the union may not be the global process group.

cc: @wanchaol @wz337 Would DeviceMesh support something like: We have a 2D submesh from some global mesh (possibly the global mesh is just the 2D mesh), and we call a collective over two dimensions of the mesh? Could DeviceMesh initialize new process groups if needed under the hood?

@lxg2015
Contributor Author

lxg2015 commented Aug 31, 2023

@awgu You are right, there is indeed an issue with the global process group. I have modified the logic here.

Now the parameters on gpu0 are first synchronized to gpu1,2,...,7 on node 0, and then, over state._inter_node_pg, they are synchronized from gpu0 to gpu8, gpu1 to gpu9, gpu2 to gpu10, and so on, so the params on gpu0 are broadcast to all ranks.
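
In other words, something like the following two-stage broadcast (a minimal sketch with a hypothetical helper name; the PR itself just calls FSDP's existing _sync_module_params_and_buffers a second time over the inter-node group):

import torch.distributed as dist

def _broadcast_from_group_leader(tensors, group):
    # The leader is the group's local rank 0, translated to its global rank.
    leader = dist.get_global_rank(group, 0)
    for t in tensors:
        dist.broadcast(t, src=leader, group=group)

def sync_from_rank0(params, intra_node_pg, inter_node_pg):
    tensors = [p.detach() for p in params]
    # Stage 1: broadcast within each node. Afterwards, every GPU on node 0
    # holds global rank 0's values (node 1 still holds its own leader's values).
    _broadcast_from_group_leader(tensors, intra_node_pg)
    # Stage 2: broadcast across nodes (gpu0 -> gpu8, gpu1 -> gpu9, ...), so every
    # rank ends up with the values that originated on global rank 0.
    _broadcast_from_group_leader(tensors, inter_node_pg)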

Is this all right? Thanks for your reply.

@awgu
Contributor

awgu commented Aug 31, 2023

@lxg2015 This approach seems reasonable to me. I am wondering if you can add a unit test in https://github.com/pytorch/pytorch/blob/main/test/distributed/fsdp/test_fsdp_hybrid_shard.py.

The unit test probably needs 4 GPUs (shard across 2 and replicate across 2).
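
For reference, a rough sketch of what such a test could look like (helper names and tensor values here are assumptions; the merged test in test/distributed/fsdp/test_fsdp_hybrid_shard.py may differ). It assumes a 4-GPU process group is already initialized:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def check_hsdp_sync_module_states():
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Every rank must create every group; each rank then picks its own groups.
    pg01, pg23 = dist.new_group([0, 1]), dist.new_group([2, 3])   # sharding groups
    pg02, pg13 = dist.new_group([0, 2]), dist.new_group([1, 3])   # replication groups
    shard_pg = pg01 if rank in (0, 1) else pg23
    replicate_pg = pg02 if rank in (0, 2) else pg13

    # Make ranks disagree on purpose: only rank 0 starts from zeros.
    model = nn.Linear(8, 8, device="cuda")
    with torch.no_grad():
        val = 0.0 if rank == 0 else 1.0
        model.weight.fill_(val)
        model.bias.fill_(val)

    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        process_group=(shard_pg, replicate_pg),  # (shard, replicate)
        sync_module_states=True,                 # broadcast rank 0's states before sharding
        device_id=torch.cuda.current_device(),
    )

    # After the fix, every rank's local shard should come from rank 0's all-zero weights.
    for p in model.parameters():
        assert (p == 0).all()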

@lxg2015
Contributor Author

lxg2015 commented Sep 1, 2023

@awgu I added a commit with a unit test, and it passes on both 4 GPUs and 8 GPUs.

Contributor

@awgu awgu left a comment


Thank you @lxg2015!

model = fsdp_ctor(model)

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    assert (model.lin1.weight == 0).all()
Contributor


nit: I think the error messaging might be better if we use self.assertTrue((model.lin1.weight == 0).all())? Or at least, I think using self.assert<...> is the common practice.

Contributor Author


Yes, I have changed the assert to self.assertTrue.

@@ -516,6 +516,10 @@ def _init_param_handle_from_module(
        _sync_module_params_and_buffers(
            fully_sharded_module, managed_params, state.process_group
        )
        if hasattr(state, '_inter_node_pg'):
Contributor


nit: I wonder if this might be brittle for checking if we are using HSDP. Perhaps, we can do getattr(state, "_inter_node_pg", None) is not None?
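
In the context of the diff above, the suggested check might look like this (a sketch only, reusing the names from the surrounding FSDP code, not the exact merged diff):

inter_node_pg = getattr(state, "_inter_node_pg", None)
if inter_node_pg is not None:  # only HSDP states carry an inter-node process group
    _sync_module_params_and_buffers(
        fully_sharded_module, managed_params, inter_node_pg
    )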

Contributor Author


Yes, I have changed this.

@lxg2015
Contributor Author

lxg2015 commented Sep 7, 2023

Hello @awgu, can this PR be merged? To reduce memory consumption, many other packages, such as accelerate, only load the checkpoint on gpu 0. If we don't fix this, it leads to abnormal training loss.

@awgu awgu self-assigned this Sep 7, 2023
@awgu
Contributor

awgu commented Sep 7, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 7, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed.

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here


@awgu
Contributor

awgu commented Sep 7, 2023

@pytorchbot rebase -s

@pytorch pytorch deleted a comment from pytorch-bot bot Sep 7, 2023
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

When using FSDP in hybrid_shard mode, state.process_group only contains GPUs 0-7 on node 0, so the GPUs on node 1 cannot receive the parameters; setting process_group to the default (global) group can fix this issue.
First broadcast the params from gpu0 to gpu1,...,7 on node 0, then broadcast from gpu0 to gpu8, gpu1 to gpu9, gpu2 to gpu10, and so on, so the params on gpu0 are broadcast to all ranks. This also works when there are more nodes.
@pytorchmergebot
Collaborator

Successfully rebased lxg2015-patch-3 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout lxg2015-patch-3 && git pull --rebase)

@wz337
Contributor

wz337 commented Sep 9, 2023

> Thanks for pointing this issue out!
>
> I think that the current broadcast semantics over only the sharding process group is an issue. However, I am not sure if using the global process group is the right thing to do in general.
>
> I think the right semantics are to "union" the sharding process group and replication process group and broadcast over this unioned process group. When only using HSDP alone, this union is the global process group. However, if HSDP does not have full ownership over the cluster (e.g. if composing with some other parallelism), then the union may not be the global process group.
>
> cc: @wanchaol @wz337 Would DeviceMesh support something like: We have a 2D submesh from some global mesh (possibly the global mesh is just the 2D mesh), and we call a collective over two dimensions of the mesh? Could DeviceMesh initialize new process groups if needed under the hood?

Yes, I think you can call a collective over two dimensions of the mesh. I believe Wanchao removed DeviceMesh's collectives, since they were just a thin layer over functional collectives, so you should be able to use functional collectives directly for this. If I understand your use case correctly, this may be what you are looking for. Code pointer: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L161

And yes, DeviceMesh would initialize new process groups if needed. The current logic is that for a mesh matching the world size, it will reuse the default group if it has already been initialized; for sub-PGs, it will go through group creation. See pointer:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/device_mesh.py#L213
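
For illustration, a hedged sketch of that DeviceMesh-based approach (module paths and the string form of get_group have moved across releases; at the time of this PR DeviceMesh lived under torch.distributed._tensor.device_mesh):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 2 nodes x 8 GPUs laid out as a (replicate, shard) mesh.
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("replicate", "shard"))

# Each mesh dimension is backed by a process group, created on demand if needed.
shard_pg = mesh.get_group("shard")          # this rank's intra-node group
replicate_pg = mesh.get_group("replicate")  # this rank's inter-node group

# Broadcasting over both dimensions (shard first, then replicate) reaches every
# rank of the 2D submesh without assuming the mesh spans the whole world.
t = torch.ones(4, device="cuda")
dist.broadcast(t, src=dist.get_global_rank(shard_pg, 0), group=shard_pg)
dist.broadcast(t, src=dist.get_global_rank(replicate_pg, 0), group=replicate_pg)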

cc. @wanchaol

@lxg2015
Contributor Author

lxg2015 commented Sep 9, 2023

Hi @awgu, all checks appear to be fine. Can this PR be merged? :grin: Or is DeviceMesh a better way to fix it?

@awgu
Contributor

awgu commented Sep 11, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@haichaoyu

haichaoyu commented Mar 22, 2024

Hi @lxg2015, @awgu , thanks for the PR and reviewing. One question here.

Here, _get_orig_params is called, and it does not return FlatParameters. However, when an auto_wrap_policy is applied, there will actually be some FlatParameters due to nested wrapping. In that case, won't some model parameters fail to be synced, leaving them inconsistent across ranks? Thanks!

@awgu
Contributor

awgu commented Mar 25, 2024

@haichaoyu Each FullyShardedDataParallel instance should be responsible for syncing its own managed parameters, so you would want to pass sync_module_states=True for all FullyShardedDataParallel instances to broadcast the entire model from rank 0.

For example, suppose you had a transformer model with transformer blocks, where you apply FSDP to each transformer block and then finally to the root transformer. When you apply FSDP to each transformer block with sync_module_states=True, it broadcasts the transformer block's parameters from rank 0. Finally, when you apply FSDP to the root, it broadcasts the root's parameters (e.g. embedding weight, output projection weight) from rank 0 and does not re-broadcast the already flattened transformer block parameters.
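A small sketch of that wrapping pattern (the model and class names here are illustrative, not taken from the PR):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.mlp(self.attn(x))

class Transformer(nn.Module):
    def __init__(self, dim=128, n_layers=4, vocab=1000):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.blocks = nn.ModuleList(Block(dim) for _ in range(n_layers))
        self.out = nn.Linear(dim, vocab)

    def forward(self, idx):
        x = self.embed(idx)
        for blk in self.blocks:
            x = blk(x)
        return self.out(x)

model = Transformer().cuda()
# Auto-wrapping each Block creates a nested FSDP instance per block. With
# sync_module_states=True, each instance (every block and the root) broadcasts
# only the parameters it directly manages from rank 0 before sharding them, so
# the whole model ends up consistent without re-broadcasting flattened params.
model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({Block}),
    sync_module_states=True,
    device_id=torch.cuda.current_device(),
)
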

Let me know if this makes sense!

@haichaoyu

Got it. Thanks for the detailed explanation!

@haichaoyu

@awgu Another related question.
The second time _sync_module_params_and_buffers is called, the buffers' FSDP_SYNCED flag is already set to True, so they will not be synced between nodes. Does this cause an inter-node inconsistency if the model has randomly initialized buffers? Thanks!

@awgu
Contributor

awgu commented Apr 9, 2024

If a parent module re-initializes the buffer of a child module, where the child module is part of a different FSDP wrapping, then yes, this could cause issues. The general guidance for FSDP is that each module should only initialize its directly owned parameters/buffers to avoid cases like this.
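
A tiny sketch of that guidance, with illustrative module names:

import torch
import torch.nn as nn

class Child(nn.Module):
    def __init__(self):
        super().__init__()
        # The child owns and initializes its own buffer.
        self.register_buffer("noise", torch.randn(8))

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()
        # Avoid re-initializing the child's buffer here, e.g.
        #   self.child.noise.normal_()
        # If Parent and Child end up in different FSDP instances, the child's
        # buffer may already be marked as synced, and a later re-init would not
        # be broadcast again, leaving ranks inconsistent.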

Labels: ciflow/trunk, Merged, open source, release notes: distributed (fsdp)