Skip to content

Conversation

fduwjj
Copy link
Contributor

@fduwjj fduwjj commented Aug 22, 2025

Stack from ghstack (oldest at bottom):

Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:

  1. In unit test, we assume users can call dp_cp_mesh._flatten() many times but no backend will be created (aka cached).
  2. From the implementation of slicing, we actually throw exception erroring out doing the _flatten more than once. But there is bug which was partially fixed in Do not incorrectly chain each of the strings as iterables #160709 but it does not fixed the check for the case when we call the _flatten twice.

What's more important question to ask is, what behavior we want for _flatten? Do we allow calling _flatten multiple times (with same mesh_name)? I think we should, why?

  1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to eq).
  2. We actually cached the flattened mesh today inside root_to_flatten_mapping and actually do the early return but that line will never be reached if we error out before that.

Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.

cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci

Copy link

pytorch-bot bot commented Aug 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161311

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ffb69f4 with merge base 5babb4d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

fduwjj added a commit that referenced this pull request Aug 22, 2025
ghstack-source-id: 208f663
Pull Request resolved: #161311
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Aug 22, 2025
@fegin
Copy link
Contributor

fegin commented Aug 22, 2025

I really thought we allow flattening twice. If we already cached the flattened ones, it should be no harm to allow flattening twice.

Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
fduwjj added a commit that referenced this pull request Aug 22, 2025
ghstack-source-id: 24ae4ed
Pull Request resolved: #161311
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
fduwjj added a commit that referenced this pull request Aug 22, 2025
ghstack-source-id: 65a1c89
Pull Request resolved: #161311
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
fduwjj added a commit that referenced this pull request Aug 23, 2025
ghstack-source-id: f48611f
Pull Request resolved: #161311
@fduwjj fduwjj requested a review from wconstab August 23, 2025 02:30
@fduwjj fduwjj added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 23, 2025
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
fduwjj added a commit that referenced this pull request Aug 23, 2025
ghstack-source-id: 852444b
Pull Request resolved: #161311
device_mesh.mesh_dim_names
):
return device_mesh

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, what if users flatten 1D but with a different name? What's the behavior we should expect?

Copy link
Contributor Author

@fduwjj fduwjj Aug 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we throw exception that is an invalid use right? That's why we have the second check after and.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about the case where users flattened the (2,4) mesh and somewhere else flattened the (4,2) mesh- in their mind there might not be a clear relationship between the 2 (8,) meshes they created. Perhaps they want to use different names for them, for the logical use case for those meshes. However we should still allow the flatten and reuse the PG if possible. Maybe this is not a common case, but I also don't see the reason we need to error for flattening twice with different names. That's my thinking anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wconstab I see, that is outside the scope of this PR. For this one I just want to fix the current behavior for now.

@albanD albanD removed the DeviceMesh label Sep 2, 2025
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
fduwjj added a commit that referenced this pull request Sep 10, 2025
ghstack-source-id: e2c998f
Pull Request resolved: #161311
@fduwjj
Copy link
Contributor Author

fduwjj commented Sep 10, 2025

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in pytorch#160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.

Pull Request resolved: pytorch#161311
Approved by: https://github.com/fegin
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in pytorch#160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.

Pull Request resolved: pytorch#161311
Approved by: https://github.com/fegin
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in pytorch#160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.

Pull Request resolved: pytorch#161311
Approved by: https://github.com/fegin
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in pytorch#160709 but it does not fixed the check for the case when we call the `_flatten` twice.

What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that  line will never be reached if we error out before that.

Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.

Pull Request resolved: pytorch#161311
Approved by: https://github.com/fegin
@github-actions github-actions bot deleted the gh/fduwjj/191/head branch October 11, 2025 02:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: DeviceMesh

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants