-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[DeviceMesh] Clarifying flatten use case #161311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161311
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit ffb69f4 with merge base 5babb4d ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
I really thought we allow |
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta [ghstack-poisoned]
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta [ghstack-poisoned]
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta [ghstack-poisoned]
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta [ghstack-poisoned]
device_mesh.mesh_dim_names | ||
): | ||
return device_mesh | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, what if users flatten 1D but with a different name? What's the behavior we should expect?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then we throw exception that is an invalid use right? That's why we have the second check after and.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking about the case where users flattened the (2,4) mesh and somewhere else flattened the (4,2) mesh- in their mind there might not be a clear relationship between the 2 (8,) meshes they created. Perhaps they want to use different names for them, for the logical use case for those meshes. However we should still allow the flatten and reuse the PG if possible. Maybe this is not a common case, but I also don't see the reason we need to error for flattening twice with different names. That's my thinking anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wconstab I see, that is outside the scope of this PR. For this one I just want to fix the current behavior for now.
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in #160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in pytorch#160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it. Pull Request resolved: pytorch#161311 Approved by: https://github.com/fegin
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in pytorch#160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it. Pull Request resolved: pytorch#161311 Approved by: https://github.com/fegin
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in pytorch#160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it. Pull Request resolved: pytorch#161311 Approved by: https://github.com/fegin
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in pytorch#160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it. Pull Request resolved: pytorch#161311 Approved by: https://github.com/fegin
Stack from ghstack (oldest at bottom):
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
dp_cp_mesh._flatten()
many times but no backend will be created (aka cached)._flatten
more than once. But there is bug which was partially fixed in Do not incorrectly chain each of the strings as iterables #160709 but it does not fixed the check for the case when we call the_flatten
twice.What's more important question to ask is, what behavior we want for
_flatten
? Do we allow calling_flatten
multiple times (with same mesh_name)? I think we should, why?root_to_flatten_mapping
and actually do the early return but that line will never be reached if we error out before that.Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.
cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci