
Remove HSDP validation check #112435

Merged

Conversation

@mvpatel2000 (Contributor) commented Oct 30, 2023

Currently, HSDP validates that all intra-/inter-node PGs are the same. This makes sense if you are using HSDP with no other forms of parallelism, and it is a nice but not strictly necessary sanity check.

However, if you want to mix HSDP with other forms of parallelism, say tensor parallelism on the FFN of a transformer block, the intra-/inter-node PGs will be different for that layer. The check raises errors in this scenario, so we need to remove this assumption.
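For illustration, a minimal sketch of the kind of wrapping that currently trips the check. The module, rank lists, and group sizes are hypothetical, and it assumes the default process group has already been initialized on every rank:

```python
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Sketch only: assumes dist.init_process_group(...) has run on every rank
# and that the (hypothetical) rank lists below match the actual world size.
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(128, 128)
        self.ffn = nn.Linear(128, 128)

    def forward(self, x):
        return self.ffn(self.attn(x))

block = Block().cuda()

# The FFN participates in tensor parallelism, so its HSDP (shard, replicate)
# groups differ from the groups used for the rest of the block.
ffn_pgs = (dist.new_group(ranks=[0, 1]), dist.new_group(ranks=[0, 2]))
default_pgs = (dist.new_group(ranks=[0, 1, 2, 3]), dist.new_group(ranks=[0, 4]))

block.ffn = FSDP(
    block.ffn,
    process_group=ffn_pgs,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
block = FSDP(
    block,
    process_group=default_pgs,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
# With the current validation, the first forward pass fails in _lazy_init
# with a "process groups do not match" ValueError.
```

Removing the check lets each HSDP instance keep its own (shard, replicate) pair, which is what a per-layer tensor-parallel setup needs.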

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225

@pytorch-bot (bot) commented Oct 30, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112435

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e08159e with merge base d444a3b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@awgu (Contributor) commented Oct 30, 2023

If I understand correctly, there is still some value in the check; however, it is currently overly strict and is problematic for manual wrapping + HSDP. I think there was someone internally working on relaxing the check.

The valuable part of the check is that, if you are using HSDP, each HSDP instance that spans the same ranks should use the same process groups. We do not want to create a different pair of process groups per HSDP instance.
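As a minimal sketch of the intended pattern (rank lists are hypothetical, `model`/`model.lin1` stand in for any module and submodule, and the default process group is assumed to be initialized), the (shard, replicate) pair is constructed once and reused for every HSDP instance that spans the same ranks:

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Hypothetical rank lists: construct the intra-/inter-node groups once...
intra_node_ranks = [0, 1, 2, 3]   # ranks on this node
inter_node_ranks = [0, 4]         # one rank per node
shard_pg = dist.new_group(ranks=intra_node_ranks)
replicate_pg = dist.new_group(ranks=inter_node_ranks)

# ...and reuse the same pair for every HSDP instance on these ranks,
# rather than calling dist.new_group() separately per wrapped submodule.
model.lin1 = FSDP(
    model.lin1,
    process_group=(shard_pg, replicate_pg),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
model = FSDP(
    model,
    process_group=(shard_pg, replicate_pg),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```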

@mvpatel2000 (Contributor, Author) commented:
@awgu got it! If someone is working on it, feel free to close this PR then :)

@awgu (Contributor) commented Oct 30, 2023

Let me follow up on the progress of that PR and get back to you!

@albanD added the triaged label Nov 2, 2023
@awgu (Contributor) commented Nov 15, 2023

@fegin @wz337 Is there anything from the checkpointing side that requires each FSDP instance to use the HSDP process groups?

If not, then I think removing this requirement/check sounds good to me (and we would need to remove the unit test).

```python
@skip_if_lt_x_gpu(2)
def test_hybrid_shard_pg_mismatch_raises(self):
    model = MyModel().cuda()
    intra_pg = self.process_group
    inter_pg = dist.new_group(ranks=[self.rank])
    # Mismatched process groups for intra-node
    model.lin1 = FSDP(
        model.lin1,
        process_group=(intra_pg, inter_pg),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    )
    model = FSDP(
        model,
        process_group=(dist.new_group(), dist.new_group()),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    )
    # Errors during _lazy_init
    inp = torch.randn(4, 10)
    with self.assertRaisesRegex(
        ValueError, "intra-node process groups do not match"
    ):
        model(inp)
    # Mismatched process groups for inter-node
    model = MyModel().cuda()
    model.lin1 = FSDP(
        model.lin1,
        process_group=(intra_pg, inter_pg),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    )
    model = FSDP(
        model,
        process_group=(intra_pg, dist.new_group()),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    )
    with self.assertRaisesRegex(
        ValueError, "inter-node process groups do not match"
    ):
        model(inp)
```

@mvpatel2000 (Contributor, Author) commented, quoting @awgu's question above:

> Is there anything from the checkpointing side that requires each FSDP instance to use the HSDP process groups?
>
> If not, then I think removing this requirement/check sounds good to me (and we would need to remove the unit test).

@awgu @fegin @wz337 bumping this request! would love to have this issue resolved

@mvpatel2000 (Contributor, Author) commented:
@awgu @fegin @wz337 bumping this please!

@fegin added the ciflow/trunk and ciflow/periodic labels Jan 30, 2024
@pytorch-bot (bot) commented Jan 30, 2024

Please seek CI approval before scheduling CIFlow labels

@pytorch-bot (bot) removed the ciflow/periodic label Jan 30, 2024
@fegin (Contributor) commented Jan 30, 2024

I think it is okay to remove the check. Will let @wz337 review again.

@wz337 (Contributor) commented Jan 30, 2024

> Is there anything from the checkpointing side that requires each FSDP instance to use the HSDP process groups?
>
> If not, then I think removing this requirement/check sounds good to me (and we would need to remove the unit test).

We are relying on DTensor to do the all-gather and chunk, so we don't use the HSDP process groups directly. I think it should be fine to remove this requirement.

@wz337 (Contributor) left a review comment:

I think it's ok to remove this check.

Could you also include removing the unit test in the PR as @awgu mentioned so CI doesn't break?

https://github.com/pytorch/pytorch/blob/main/test/distributed/fsdp/test_fsdp_hybrid_shard.py#L120

@fegin added the ciflow/periodic label Jan 30, 2024
@github-actions (bot) added the oncall: distributed and ciflow/inductor labels Feb 1, 2024
@pytorch-bot (bot) commented Feb 1, 2024

Please seek CI approval before scheduling CIFlow labels

@mvpatel2000 (Contributor, Author) commented:
@wz337 test removed!

@wconstab sorry -- updated to be more clear :)

@mvpatel2000 (Contributor, Author) commented:
@pytorchmergebot merge

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:
Merge failed

Reason: 3 mandatory check(s) failed.

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job

Failing merge rule: Core Maintainers

@Skylion007 (Collaborator) commented:
@pytorchbot merge -r

@pytorchmergebot (Collaborator) commented:
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator) commented:
Successfully rebased mvpatel2000/remove-hsdp-validate onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout mvpatel2000/remove-hsdp-validate && git pull --rebase)

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorch-bot bot pushed a commit that referenced this pull request Feb 8, 2024
Pull Request resolved: #112435
Approved by: https://github.com/wz337, https://github.com/Skylion007
@mvpatel2000 deleted the mvpatel2000/remove-hsdp-validate branch February 13, 2024 19:06
mvpatel2000 added a commit to mvpatel2000/pytorch that referenced this pull request Feb 13, 2024
Pull Request resolved: pytorch#112435
Approved by: https://github.com/wz337, https://github.com/Skylion007
atalman pushed a commit that referenced this pull request Feb 14, 2024. Its message includes:
Co-authored-by: Andrew Gu <andgu@fb.com>
resolved: #112435
resolved: #118620
Fixed `device_mesh` and auto wrap (#119064)
fix #118906.
resolved: #119064
resolved: #118638
Fixes #118639.
resolved: #119481
Labels: ciflow/inductor, ciflow/periodic, ciflow/trunk, Merged, oncall: distributed, open source, release notes: distributed (fsdp), triaged

9 participants