[DeviceMesh] Simplifying internal bookkeeping with CuTe layout #163213
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163213
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit a51ea5e with merge base f63d16c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…yout" We want to refactor the internal bookkeeping of DeviceMesh so that: Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible. Concretely, in this PR, we do the following: 1. Use the _Layout to handle all index operations rather use a map to record mesh dims. 2. Replaced`flatten_name_to_root_dims` with `flatten_name_to_root_layout`. Basically one (size, stride) pair maps to one PG. One mesh_dim_name can only map to only layout. (More than one mesh_dim_name can map to the same layout). 3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`. 4. Use a new function `check_overlap` to check layout overlap. 5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor. This is a continue of #161106 (originally messed with EasyCLA) cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…yout" We want to refactor the internal bookkeeping of DeviceMesh so that: Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible. Concretely, in this PR, we do the following: 1. Use the _Layout to handle all index operations rather use a map to record mesh dims. 2. Replaced`flatten_name_to_root_dims` with `flatten_name_to_root_layout`. Basically one (size, stride) pair maps to one PG. One mesh_dim_name can only map to only layout. (More than one mesh_dim_name can map to the same layout). 3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`. 4. Use a new function `check_overlap` to check layout overlap. 5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor. With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout. This is a continue of #161106 (originally messed with EasyCLA) cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
return slice_mesh_dims
return layout_sliced

def _get_all_submeshes(
Is this function ever used, other than in tests?
It is used internally by Shampoo. cc: @wz337
Let's add a TODO to these private APIs used by other components. This mostly means DeviceMesh should publicly support these use cases in the future.
Could we get clarity on which context this is needed in and whether the same can be achieved in other ways?
Here's what the code usage in Shampoo looks like:
# NOTE: We get all submeshes along the "replicate" dimension, then pick out
# the sub-mesh that the optimizer state is assigned to.
#
# For the example above, this would give me submeshes [[3, 27], [11, 35], [19, 43]].
# Note that the group source rank must belong to {0, 1, 2} in this case.
# Suppose the group_source_rank = 1, then this would get the submesh [11, 35].
replicate_submesh = _mesh_resources._get_all_submeshes(
device_mesh_2d, "replicate"
)[group_source_rank]
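For readers unfamiliar with `_get_all_submeshes`, here is a plain-Python sketch of what the call above computes for an assumed 2 x 3 rank grid whose first dim is "replicate" (the rank values mirror the numbers quoted in the comment; the real helper operates on a DeviceMesh, not a nested list):

```python
# Assumed 2x3 rank grid: rows are the "replicate" dim, columns are the other dim.
mesh_2d = [
    [3, 11, 19],
    [27, 35, 43],
]

# All submeshes along "replicate": fix a column index, vary the row index.
replicate_submeshes = [
    [row[j] for row in mesh_2d] for j in range(len(mesh_2d[0]))
]
print(replicate_submeshes)                      # [[3, 27], [11, 35], [19, 43]]

group_source_rank = 1                           # assumed, as in the quoted comment
print(replicate_submeshes[group_source_rank])   # [11, 35]
```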
This feels like a legitimate use case, although I don't know exactly how to spell it. It seems what is happening here is that they want to do a submesh slice on "replicate", but if they do it directly as `_mesh_resources["replicate"]` they will get the submesh for the CURRENT rank. Instead, they want to run it for `group_source_rank`. I feel this should be directly supportable, but it interacts with the current idea that a submesh is "only" a Layout, when it is also a particular coordinate picked out of the complement.
torch/distributed/_mesh_layout.py (Outdated)
previous_span = size * stride
return True

def to_remapping_tensor(
This function seems very useful, although I'm puzzled by what it actually does.
- the naming `to_remapping_tensor` is not very clear
- the examples in the docstring don't explain why they are the desired behaviors
In general, as a reader I guess I'm missing context.
`remapping_tensor` is a name we agreed on during the discussion last Monday, but I am open to other proposals. Can you elaborate on exactly what context you are missing here? Basically it is what we wrote in the comment:
> With this method, the CuTe layout serves as the backend of index bookkeeping for the mesh tensor when it comes to flatten, unflatten, and slicing operations. The actual mesh tensor still represents the actual device assignment and ranks.

Which part is not clear to you? Or are you asking why we want to do this conversion? It is because we need an actual mesh tensor to specify the device allocation and also for backend creation.
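A minimal sketch of that division of labor, assuming a toy 4-rank mesh tensor (illustrative only, not the `to_remapping_tensor` implementation): the layout produces indices, and the mesh tensor supplies the real device ranks.

```python
import torch

# Assumed example: 4 devices whose global ranks happen to be 8..11.
mesh_tensor = torch.tensor([8, 9, 10, 11])

# A row-major (2, 2) layout over those 4 positions: sizes=(2, 2), strides=(2, 1).
sizes, strides = (2, 2), (2, 1)
idx = torch.tensor(
    [[i * strides[0] + j * strides[1] for j in range(sizes[1])]
     for i in range(sizes[0])]
)

# The layout only does index bookkeeping; indexing the (flattened) mesh tensor
# recovers the actual device assignment for the reshaped mesh.
print(mesh_tensor.flatten()[idx])   # tensor([[ 8,  9], [10, 11]])
```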
…yout" We want to refactor the internal bookkeeping of DeviceMesh so that: Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible. Concretely, in this PR, we do the following: 1. Use the _Layout to handle all index operations rather use a map to record mesh dims. 2. Replaced`flatten_name_to_root_dims` with `flatten_name_to_root_layout`. Basically one (size, stride) pair maps to one PG. One mesh_dim_name can only map to only layout. (More than one mesh_dim_name can map to the same layout). 3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`. 4. Use a new function `check_overlap` to check layout overlap. 5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor. With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout. This is a continue of #161106 (original one got messed with EasyCLA) cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…yout" We want to refactor the internal bookkeeping of DeviceMesh so that: Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible. Concretely, in this PR, we do the following: 1. Use the _Layout to handle all index operations rather use a map to record mesh dims. 2. Replaced`flatten_name_to_root_dims` with `flatten_name_to_root_layout`. Basically one (size, stride) pair maps to one PG. One mesh_dim_name can only map to only layout. (More than one mesh_dim_name can map to the same layout). 3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`. 4. Use a new function `check_overlap` to check layout overlap. 5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor. With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout. This is a continue of #161106 (original one got messed with EasyCLA) cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…yout" We want to refactor the internal bookkeeping of DeviceMesh so that: Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible. Concretely, in this PR, we do the following: 1. Use the _Layout to handle all index operations rather use a map to record mesh dims. 2. Replaced`flatten_name_to_root_dims` with `flatten_name_to_root_layout`. Basically one (size, stride) pair maps to one PG. One mesh_dim_name can only map to only layout. (More than one mesh_dim_name can map to the same layout). 3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`. 4. Use a new function `check_overlap` to check layout overlap. 5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor. With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout. This is a continue of #161106 (original one got messed with EasyCLA) cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
torch/distributed/_mesh_layout.py (Outdated)

def to_remapping_tensor(
    self,
    original_mesh_tensor: torch.Tensor,
The current `original_mesh_tensor` must have the same dimensionality as the layout, is that correct? In the future, if we would like to let users pass a device list/tensor into `DeviceMesh`, can it be a one-dimensional list/tensor? So the access is layout -> coordinates -> indices -> device_list[indices].
> In the future, if we would like to let users pass the device list/tensor into DeviceMesh

We already allow that today.

> The current original_mesh_tensor must have the same dimensionality as the layout

It needs the same world size; otherwise indexing with `original_mesh_tensor[idx]` will break.

> can it be a one-dimensional list/tensor?

Yes, that's why I flatten it at the end before indexing. It does not matter what shape `original_mesh_tensor` has; the final shape is based on the given layout.
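A tiny sketch of the flatten-before-indexing point (shapes and values are assumed): a 1-D device list and a 2-D mesh tensor give the same result, because only the element count has to match the layout.

```python
import torch

idx = torch.tensor([[0, 1], [2, 3]])         # indices produced by some 2x2 layout
flat_list = torch.tensor([8, 9, 10, 11])     # 1-D device list
mesh_2d = torch.tensor([[8, 9], [10, 11]])   # same devices as a 2-D mesh tensor

# Flattening first makes the input shape irrelevant; the output shape follows idx.
assert torch.equal(flat_list.flatten()[idx], mesh_2d.flatten()[idx])
print(flat_list.flatten()[idx])              # tensor([[ 8,  9], [10, 11]])
```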
torch/distributed/_mesh_layout.py (Outdated)

# This is important because the indices generated by the layout will be larger than the original mesh tensor
# when the original mesh tensor does not contain all ranks in the world. So we need to scale the layout's stride
# by world_size // mesh_tensor.numel() so that the indices generated by the layout will be within the range of
# the original mesh tensor.
if original_mesh_tensor.numel() != world_size:
Again, is this because of the current DeviceMesh implementation, where the layout covers more than the mesh tensor?
Yes, because we still need to support slicing a flattened dim from the root mesh. For that we need to keep the original layout, not the layout derived from the submesh's mesh tensor.
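To make the stride scaling above concrete, a rough numeric illustration under assumed values (world_size = 8, a submesh holding every other rank); this only shows why the raw layout indices overshoot and how shrinking the stride by `world_size // mesh_tensor.numel()` brings them back into range, not the actual implementation:

```python
import torch

world_size = 8
mesh_tensor = torch.tensor([0, 2, 4, 6])      # assumed submesh: every other rank

size, stride = 4, 2                           # layout expressed in world units
print(torch.arange(size) * stride)            # tensor([0, 2, 4, 6]) -> exceeds numel() == 4

scale = world_size // mesh_tensor.numel()     # 2
local_idx = torch.arange(size) * (stride // scale)
print(mesh_tensor[local_idx])                 # tensor([0, 2, 4, 6]) -> valid indexing
```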
torch/distributed/device_mesh.py (Outdated)

] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
layout: Optional[_MeshLayout] = None,
Since this is a keyword argument, can we move it a bit earlier, at least earlier than `_init_backend` and `_rank`, which look like private arguments or arguments that are going to be deprecated, lol.
no, I want it to be the last one and private for now.
…yout" We want to refactor the internal bookkeeping of DeviceMesh so that: Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible. Concretely, in this PR, we do the following: 1. Use the _Layout to handle all index operations rather use a map to record mesh dims. 2. Replaced`flatten_name_to_root_dims` with `flatten_name_to_root_layout`. Basically one (size, stride) pair maps to one PG. One mesh_dim_name can only map to only layout. (More than one mesh_dim_name can map to the same layout). 3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`. 4. Use a new function `check_overlap` to check layout overlap. 5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor. With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout. This is a continue of #161106 (original one got messed with EasyCLA) cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
@pytorchbot merge
This PR has pending changes requested. Please address the comments and update the PR before merging.
Edward has already retracted his request for changes.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot revert -c ghfirst -m "caused internal test failure"
@pytorchbot successfully started a revert job. Check the current status here.
@fduwjj your PR has been successfully reverted.
#163213)" This reverts commit b098514. Reverted #163213 on behalf of https://github.com/yangw-dev due to caused internal test failure ([comment](#163213 (comment)))
…yout" We want to refactor the internal bookkeeping of DeviceMesh so that: Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible. Concretely, in this PR, we do the following: 1. Use the `_MeshLayout` to handle all index operations rather use a map to record mesh dims. 2. Removed `flatten_name_to_root_dims`, because now we can directly get layout from a flattened device mesh. 3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`. 4. Use the newly added function `check_overlap` to check layout overlap. 5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The reason is that layout acts as a backend of mesh tensor bookkeeping (indexing indices), it needs to be used as indices for remap back to the mesh tensor for new DeviceMesh generation and backend init. For example, in the case of 2K to 4K, the underlying layout is (2K, 1) but the actual value of the mesh tensor is [2K, 2K+1, ....,]. While flattening, slicing, we need to remap the layout back to the new mesh tensor so it maps the actual device allocation. For example, in the 2K to 4K case, if the shape is (1K, 1K) with dim_names ("dp", "tp"). Then when slicing "tp", the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ... 4K-1). not the global ranks generated from the layout. (1K, 1). Verified that loss curve is very close for DeepSeekV3 on torchtitan, note that exact same match is challenging because even if we run the baseline twice, the loss curve does not exactly match. <img width="1113" height="490" alt="image" src="https://github.com/user-attachments/assets/7877b5a4-337e-4ad8-b878-2378f4f0f38d" /> The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor. With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout. This is a continue of #161106 (original one got messed with EasyCLA) cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
#164731) Fixes #163330. I tried to reproduce the bug with my 4-GPU setup (the original issue used 8 GPUs). I created several different test scenarios, trying to trigger the bug by:
- creating two different device meshes
- slicing them in various ways
- checking if get_root_mesh() would get confused

but the bug didn't show up! Everything worked correctly in `2.10`. I found that there was a massive refactoring of the `DeviceMesh` code (PR #163213) that landed on October 2nd. That PR completely rewrote how `DeviceMesh` tracks the relationships between parent meshes and submeshes. It seems like this refactoring fixed the bug! But I added a regression test to make sure it doesn't come back. The test (`test_get_root_mesh_multiple_independent_meshes`) does exactly what the bug report described:
- creates two independent meshes
- slices them both
- verifies that each submesh correctly points back to its real parent
- makes sure submeshes from mesh1 don't incorrectly claim mesh2 as their parent

Pull Request resolved: #164731 Approved by: https://github.com/fduwjj
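A hedged sketch of what such a regression test could look like; it assumes a multi-process launch (e.g. torchrun with 8 ranks), a CPU/gloo setup, and that the private root-mesh helper is `_mesh_resources.get_root_mesh` (names may differ across PyTorch versions):

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh, _mesh_resources


def test_two_independent_meshes() -> None:
    # Two unrelated meshes over the same 8 ranks, sliced independently.
    mesh1 = init_device_mesh("cpu", (2, 4), mesh_dim_names=("dp", "tp"))
    mesh2 = init_device_mesh("cpu", (4, 2), mesh_dim_names=("pp", "dp"))

    sub1 = mesh1["tp"]
    sub2 = mesh2["dp"]

    # Each submesh must point back to its own root, never to the other mesh.
    assert _mesh_resources.get_root_mesh(sub1) is mesh1
    assert _mesh_resources.get_root_mesh(sub2) is mesh2


if __name__ == "__main__":
    test_two_independent_meshes()
    dist.destroy_process_group()
```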
Stack from ghstack (oldest at bottom):
We want to refactor the internal bookkeeping of DeviceMesh to simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations such as flattening non-contiguous dims, reshape, and unflatten (we leveraged the CuTe layout). The new layout also makes non-contiguous slicing, flattening, and transposing possible.
Concretely, in this PR, we do the following:
1. Use the `_MeshLayout` to handle all index operations rather than a map that records mesh dims.
2. Removed `flatten_name_to_root_dims`, because now we can directly get the layout from a flattened device mesh.
3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`.
4. Use the newly added function `check_overlap` to check layout overlap.
5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The reason is that the layout acts as a backend for mesh tensor bookkeeping (index bookkeeping), so it needs to be used as indices that remap back to the mesh tensor for new DeviceMesh generation and backend init. For example, in the 2K-to-4K case, the underlying layout is (2K, 1) but the actual values of the mesh tensor are [2K, 2K+1, ...]. While flattening or slicing, we need to remap the layout back to the new mesh tensor so that it maps to the actual device allocation. For example, in the 2K-to-4K case, if the shape is (1K, 1K) with dim_names ("dp", "tp"), then when slicing "tp" the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ..., 4K-1), not the global ranks (1K, 1) generated from the layout (a small numeric sketch follows this list).

Verified that the loss curve is very close for DeepSeekV3 on torchtitan; note that an exact match is challenging because even running the baseline twice does not produce exactly matching loss curves.
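A small-number sketch of the 2K-to-4K example in item 5 (using K = 2, so "2K" = 4 and "4K" = 8; illustrative only): the "tp" layout yields local indices, and the real mesh tensor supplies the second half of the world's ranks.

```python
import torch

world_mesh = torch.arange(4, 8)        # ranks [4, 5, 6, 7], i.e. [2K, ..., 4K-1]
mesh_2d = world_mesh.reshape(2, 2)     # shape (1K, 1K) with dims ("dp", "tp")

tp_indices = torch.arange(2)           # indices the "tp" layout would generate
print(mesh_2d[0][tp_indices])          # tensor([4, 5]) -> (2K, ..., 3K-1)
print(mesh_2d[1][tp_indices])          # tensor([6, 7]) -> (3K, ..., 4K-1)
```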
The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor.
With this refactoring we also enabled slicing and flattening of non-contiguous dims of a device mesh, which would be hard to implement without the CuTe layout.
This is a continuation of #161106 (the original got tangled up with EasyCLA).
cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci