[DeviceMesh] Update slicing documentation to include nD and non-continuous slicing #132311
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132311
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 9ff80df with merge base 3864a2d:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: a9014ef Pull Request resolved: pytorch#132311
We should change our wording, given that we are removing the parent mesh concept. Maybe this PR could be combined with the parent -> root PR.
torch/distributed/device_mesh.py
Outdated
Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
The following program runs on each process/rank in an SPMD manner in a world size of 8.
In the first example:
Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D child DeviceMesh:([0, 1, 2, 3]).
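The rank groupings quoted in this hunk can be checked without launching any processes. The following is a minimal sketch in plain Python, assuming a (2, 4) mesh with dims ("dp", "tp") laid out in row-major order over ranks 0..7; `build_mesh` and `slice_ranks` are hypothetical helpers written for illustration and are not part of `torch.distributed`.

```python
from itertools import product


def build_mesh(shape):
    """Map each mesh coordinate to a rank, assuming row-major rank order."""
    coords = product(*(range(n) for n in shape))
    return {coord: rank for rank, coord in enumerate(coords)}


def slice_ranks(shape, dim_names, keep, rank):
    """List the ranks of the 1D/nD slice along `keep` that contains `rank`."""
    mesh = build_mesh(shape)
    my_coord = next(c for c, r in mesh.items() if r == rank)
    keep_idx = [dim_names.index(name) for name in keep]
    # A rank belongs to the same slice if it matches `rank` on every
    # dimension that is NOT being kept.
    fixed = [i for i in range(len(shape)) if i not in keep_idx]
    return [r for c, r in mesh.items() if all(c[i] == my_coord[i] for i in fixed)]


if __name__ == "__main__":
    shape, names = (2, 4), ("dp", "tp")
    for rank in range(8):
        print(rank, slice_ranks(shape, names, ("tp",), rank),
              slice_ranks(shape, names, ("dp",), rank))
```

Under these assumptions, ranks 0 to 3 share the "tp" group [0, 1, 2, 3], and ranks 3 and 7 share the "dp" group [3, 7], which matches the wording in the docstring.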
I think we should change the wording from "child DeviceMesh" to something like "sub DeviceMesh". "Child" implies a parent somewhere, and we should avoid that term going forward. Please rename it in all the comments.
Removed "child DeviceMesh" and replaced with submesh instead.
I am not sure we want to expose the concept of a root mesh yet, as the root mesh only lives in our mesh env right now. I think we should leave the term "root mesh" out of the public API documentation.
[Edit]: If we later decide to add a public API that lets users know whether a mesh is a root mesh, then we can include the term "root mesh" in the documentation.
torch/distributed/device_mesh.py
Outdated
"""
Slice the current DeviceMesh based on the mesh_dim_name given to create a child
DeviceMesh.
Slice the current DeviceMesh based on the mesh_dim_names given to create a child
- Slice the current DeviceMesh based on the mesh_dim_names given to create a child
+ Derive a sub-DeviceMesh based on the root DeviceMesh based on the ``mesh_dim_names`` provided
torch/distributed/device_mesh.py
Outdated
Slice the current DeviceMesh based on the mesh_dim_name given to create a child
DeviceMesh.
Slice the current DeviceMesh based on the mesh_dim_names given to create a child
DeviceMesh. This supports 1D or nD slicing. The DeviceMesh created will have dimensions
- DeviceMesh. This supports 1D or nD slicing. The DeviceMesh created will have dimensions
+ The DeviceMesh created consists of the dimensions and the communicators indicated by ``mesh_dim_names``
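To make "the dimensions indicated by ``mesh_dim_names``" concrete for nD and non-contiguous slicing, here is a small sketch in plain Python. It assumes a (2, 2, 2) mesh with dims ("dp", "cp", "tp") in row-major rank order; `nd_slice` is a hypothetical helper for illustration only, and it models the rank layout, not the communicators.

```python
def nd_slice(shape, dim_names, keep, rank):
    """Return the nested rank layout of the submesh along `keep` containing `rank`."""
    # Row-major strides: the last mesh dimension varies fastest.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]

    coord = [(rank // s) % n for s, n in zip(strides, shape)]
    keep_idx = [dim_names.index(name) for name in keep]

    # Offset contributed by the dimensions that are sliced away; it stays
    # fixed at this rank's own coordinate.
    base = sum(coord[i] * strides[i] for i in range(len(shape)) if i not in keep_idx)

    def expand(dims, offset):
        # Build one nesting level per kept dimension.
        if not dims:
            return offset
        head, rest = dims[0], dims[1:]
        return [expand(rest, offset + k * strides[head]) for k in range(shape[head])]

    return expand(keep_idx, base)


if __name__ == "__main__":
    names = ("dp", "cp", "tp")
    # Non-contiguous slice: keep "dp" and "tp", drop "cp".
    print(nd_slice((2, 2, 2), names, ("dp", "tp"), rank=0))
    print(nd_slice((2, 2, 2), names, ("dp", "tp"), rank=7))
```

In this model, keeping "dp" and "tp" on rank 0 yields a 2D layout whose dimensions are exactly the two requested names, with "cp" pinned to rank 0's own "cp" coordinate.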
@pytorchmergebot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
…root mesh (#132339)

Previously, when we sliced out a submesh from a mesh, we assigned the mesh as the parent mesh of the submesh. As a result, in a 3D mesh topology, the parent mesh of a 1D mesh sliced out from the 3D mesh is different from the parent mesh of the same 1D mesh sliced out from a 2D submesh of the 3D mesh. For example:

```
from torch.distributed.device_mesh import init_device_mesh, _mesh_resources

mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]
# This would evaluate to be True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0_2))
```

We can always reconstruct the mesh needed from the mesh dim names, as long as two dims come from the same root. For simplicity, we do not see the necessity of building a tree structure to represent the child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv`, so we would have:

```
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]
# This would evaluate to be True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0_2))
```

With this change, we will have two types of meshes in an environment:
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh not created through slicing.

Pull Request resolved: #132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
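The difference between tracking a parent and tracking a root can be illustrated with a toy model. This is only a sketch: `ToyMesh` and `ToyMeshEnv` are hypothetical classes invented here, and they do not reflect how `_MeshEnv` is actually implemented. The sketch only shows why a flat submesh-to-root mapping gives the same answer regardless of the slicing path.

```python
class ToyMesh:
    """Hypothetical stand-in for a mesh; it only remembers its dim names."""

    def __init__(self, dim_names):
        self.dim_names = tuple(dim_names)


class ToyMeshEnv:
    """Hypothetical registry that records the root of every sliced mesh."""

    def __init__(self):
        self._root_of = {}  # submesh -> root mesh (keyed by object identity)

    def slice(self, mesh, names):
        sub = ToyMesh(names)
        # Record the root of the source mesh, not the source mesh itself,
        # so the slicing path does not matter.
        self._root_of[sub] = self.get_root_mesh(mesh)
        return sub

    def get_root_mesh(self, mesh):
        # A mesh that was never produced by slicing is its own root.
        return self._root_of.get(mesh, mesh)


env = ToyMeshEnv()
mesh_3d = ToyMesh(("dim0", "dim1", "dim2"))
mesh_dim0 = env.slice(mesh_3d, ("dim0",))
mesh_2d = env.slice(mesh_3d, ("dim0", "dim1"))
mesh_dim0_2 = env.slice(mesh_2d, ("dim0",))

same_root = env.get_root_mesh(mesh_dim0) is env.get_root_mesh(mesh_dim0_2)
is_root = env.get_root_mesh(mesh_3d) is mesh_3d
```

In this toy model, both 1D meshes resolve to `mesh_3d`, and a mesh is a root exactly when it maps to itself, mirroring the two cases listed above.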
Adds a new private API to flatten a DeviceMesh to a 1D DeviceMesh such that:

```
mesh_3d = init_device_mesh(
    self.device_type,
    (2, 2, 2),
    mesh_dim_names=("dp", "cp", "tp"),
)
dp_cp_mesh = mesh_3d["dp", "cp"]
# flattened_mesh on rank 0, 2, 4, 6 is DeviceMesh([0, 2, 4, 6], mesh_dim_names=('dp_cp',))
# flattened_mesh on rank 1, 3, 5, 7 is DeviceMesh([1, 3, 5, 7], mesh_dim_names=('dp_cp',))
flattened_dp_cp_mesh = dp_cp_mesh._flatten()
```

Pull Request resolved: #132632
Approved by: https://github.com/fegin, https://github.com/wanchaol
ghstack dependencies: #132310, #132311, #132339
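The flattened rank lists in the comments above can be reproduced with a short calculation. The sketch below is plain Python and assumes a row-major (2, 2, 2) layout over ranks 0..7; `flatten_ranks` is a hypothetical helper for illustration, not the `_flatten()` implementation, and the joined name simply follows the 'dp_cp' naming shown in the example.

```python
from itertools import product


def flatten_ranks(shape, dim_names, keep, rank):
    """Return (ranks, name) of the flattened 1D group along `keep` containing `rank`."""
    # Row-major strides: the last mesh dimension varies fastest.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]

    keep_idx = [dim_names.index(name) for name in keep]

    # Offset from the dimensions that are not flattened (fixed for this rank).
    base = sum(
        ((rank // strides[i]) % shape[i]) * strides[i]
        for i in range(len(shape))
        if i not in keep_idx
    )

    # Enumerate every coordinate combination over the kept dims, in order.
    ranks = [
        base + sum(k * strides[i] for k, i in zip(combo, keep_idx))
        for combo in product(*(range(shape[i]) for i in keep_idx))
    ]
    return ranks, "_".join(keep)


if __name__ == "__main__":
    names = ("dp", "cp", "tp")
    for rank in range(8):
        print(rank, flatten_ranks((2, 2, 2), names, ("dp", "cp"), rank))
```

Under these assumptions, even ranks land in the group [0, 2, 4, 6] and odd ranks in [1, 3, 5, 7], consistent with the comments in the commit message.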
Stack from ghstack (oldest at bottom):
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wconstab @d4l3k @c-p-i-o @tianyu-l @chauhang