
Conversation

@fduwjj (Contributor) commented Sep 17, 2025

Stack from ghstack (oldest at bottom):

We want to refactor the internal bookkeeping of DeviceMesh to simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations such as flattening non-contiguous dims, reshape, and unflatten (we leveraged the CuTe layout). The new layout also makes non-contiguous slicing, flattening, and transposing possible.

Concretely, in this PR, we do the following:

  1. Use the `_MeshLayout` to handle all index operations rather than using a map to record mesh dims.
  2. Removed `flatten_name_to_root_dims`, because we can now get the layout directly from a flattened device mesh.
  3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`.
  4. Use the newly added function `check_overlap` to check for layout overlap.
  5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as a CuTe layout. The layout acts as the backend of mesh tensor bookkeeping (it tracks indices), so its output needs to be used as indices that remap back into the mesh tensor when generating a new DeviceMesh and initializing backends. For example, in the 2K-to-4K case, the underlying layout is (2K, 1) but the actual values of the mesh tensor are [2K, 2K+1, ...]. When flattening or slicing, we need to remap the layout back onto the new mesh tensor so it reflects the actual device allocation. Concretely, if the shape is (1K, 1K) with dim_names ("dp", "tp"), then when slicing "tp" the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ..., 4K-1), not the global ranks generated from the (1K, 1) layout. (See the sketch after this list.)
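To make the remapping concrete, here is a minimal, self-contained sketch of the (size, stride) bookkeeping and the remap step, scaled down to a 16-rank world. `ToyLayout` and `to_remapping_tensor_sketch` are illustrative names for exposition only, not the actual `_MeshLayout` API.

```python
import torch

class ToyLayout:
    def __init__(self, sizes, strides):
        self.sizes = sizes      # e.g. (2, 4) for a ("dp", "tp") view
        self.strides = strides  # CuTe-style strides, e.g. (4, 1)

    def indices(self):
        # CuTe-style indexing: index(coord) = sum(coord[d] * stride[d]).
        coords = torch.cartesian_prod(*[torch.arange(s) for s in self.sizes])
        return (coords * torch.tensor(self.strides)).sum(dim=-1).reshape(self.sizes)

def to_remapping_tensor_sketch(layout, mesh_tensor):
    # Treat the layout's output as indices into the mesh tensor, not as ranks,
    # so the result reflects the actual device allocation.
    return mesh_tensor.flatten()[layout.indices()]

# Analogue of the 2K-to-4K case: this mesh holds only ranks 8..15 of a
# 16-rank world, viewed as ("dp", "tp") = (2, 4).
mesh_tensor = torch.arange(8, 16)
layout = ToyLayout(sizes=(2, 4), strides=(4, 1))
print(to_remapping_tensor_sketch(layout, mesh_tensor))
# tensor([[ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
# Slicing "tp" now yields [8, 9, 10, 11] or [12, 13, 14, 15] (actual ranks),
# rather than the layout's own indices [0, 1, 2, 3] / [4, 5, 6, 7].
```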

Verified that the loss curve for DeepSeekV3 on torchtitan is very close to the baseline. Note that an exact match is hard to expect: even running the baseline twice does not produce exactly matching loss curves.

(Screenshot: loss-curve comparison for DeepSeekV3 on torchtitan.)

The PR does look big, but it does not change any existing behavior of DeviceMesh, so it is a pure refactor.

With this refactoring we also enable slicing and flattening of non-contiguous dims of a device mesh, which would be hard to implement without the CuTe layout.

This is a continuation of #161106 (the original got derailed by EasyCLA).

cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci

pytorch-bot bot commented Sep 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163213

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a51ea5e with merge base f63d16c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the oncall: distributed label on Sep 17, 2025
fduwjj added a commit that referenced this pull request Sep 17, 2025
@fduwjj added the release notes: DeviceMesh and ciflow/trunk labels on Sep 17, 2025
@fduwjj fduwjj requested review from ezyang, fegin and lw September 17, 2025 23:46
fduwjj added a commit that referenced this pull request Sep 17, 2025
fduwjj added a commit that referenced this pull request Sep 18, 2025
return slice_mesh_dims
return layout_sliced

def _get_all_submeshes(
Contributor:

Is this function ever used, other than in tests?

Contributor Author:

It is used internally by Shampoo. cc: @wz337

Contributor:

Let's add a TODO to these private APIs used by other components. This mostly means DeviceMesh should expose these use cases publicly in the future.

Contributor:

Could we get clarity on which context this is needed in and whether the same can be achieved in other ways?

Contributor Author:

@lw my understanding is that the Shampoo implementation needs to know all the other sub-meshes as well. cc: @wz337

Contributor:

Here's what the code usage in Shampoo looks like:

        # NOTE: We get all submeshes along the "replicate" dimension, then pick out
        # the sub-mesh that the optimizer state is assigned to.
        #
        # For the example above, this would give me submeshes [[3, 27], [11, 35], [19, 43]].
        # Note that the group source rank must belong to {0, 1, 2} in this case.
        # Suppose the group_source_rank = 1, then this would get the submesh [11, 35].
        replicate_submesh = _mesh_resources._get_all_submeshes(
            device_mesh_2d, "replicate"
        )[group_source_rank]

Contributor:

This feels like a legitimate use case, although I don't know exactly how to spell it. What is happening here is that they want to do a submesh slice on "replicate", but if they do it directly as _mesh_resources["replicate"] they will get the submesh for the CURRENT rank; instead, they want to run it for group_source_rank. I feel this should be directly supportable, but it does interact with the current idea that a submesh is "only" a Layout, when it is also a particular coordinate picked out of the complement.
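For readers following the thread, here is a tiny illustration of the distinction being made; the shape, names, and ranks are made up and this is not the DeviceMesh API:

```python
import torch

# An 8-rank world arranged as ("replicate", "shard") = (2, 4).
mesh = torch.arange(8).reshape(2, 4)

# "All submeshes along replicate" are the columns; each column is one replica group.
replicate_submeshes = [mesh[:, i] for i in range(mesh.size(1))]
# -> [tensor([0, 4]), tensor([1, 5]), tensor([2, 6]), tensor([3, 7])]

# Slicing mesh["replicate"] directly would return the group containing the
# *current* rank; Shampoo instead indexes the full list by group_source_rank,
# i.e. it picks a specific coordinate of the complement.
group_source_rank = 1
print(replicate_submeshes[group_source_rank])  # tensor([1, 5])
```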

previous_span = size * stride
return True

def to_remapping_tensor(
Contributor:

This function seems very useful, although I'm puzzled by what it actually does.

  1. the naming to_remapping_tensor is not very clear
  2. the examples in the docstring don't explain why they are the desired behaviors

In general, as a reader, I think I'm missing context.

Contributor Author:

remapping_tensor is a name we agreed on during the discussion last Monday, but I am open to other proposals. Can you elaborate on exactly what context you are missing here?

Basically what we wrote in the comment:

With this method, the cute layout serves as the backend of indices bookkeeping for the
mesh tensor when it comes to flatten, unflatten and slicing operations. The actual mesh
tensor still represents the actual device assignment and ranks.

Which part is not clear to you? Or are you asking why we want to do this conversion at all? It is because we need an actual mesh tensor to specify the device allocation and also for backend creation.

fduwjj added a commit that referenced this pull request Sep 18, 2025
fduwjj added a commit that referenced this pull request Sep 18, 2025
fduwjj added a commit that referenced this pull request Sep 18, 2025
@fduwjj fduwjj requested a review from tianyu-l September 19, 2025 03:37

def to_remapping_tensor(
self,
original_mesh_tensor: torch.Tensor,
Contributor:

The current original_mesh_tensor must have the same dimensionality as the layout, is that correct? In the future, if we would like to let users pass the device list/tensor into DeviceMesh, can it be a one-dimensional list/tensor? The access would then be layout -> coordinates -> indices -> device_list[indices].

Contributor Author (@fduwjj), Sep 19, 2025:

> In the future, if we would like to let users pass the device list/tensor into DeviceMesh

We already allow that today.

> The current original_mesh_tensor must have the same dimensionality as the layout

It only needs the same world size; otherwise original_mesh_tensor[idx] will fail.

> can it be a one-dimensional list/tensor?

Yes, that's why I flatten it at the end before indexing. It does not matter what shape original_mesh_tensor has; the final shape is based on the given layout.
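A small sketch of that last point, with made-up shapes (not the actual `to_remapping_tensor` signature): the mesh tensor is flattened before indexing, so only its number of elements matters and the output shape comes from the layout.

```python
import torch

device_list = torch.arange(2048, 4096)                 # 1-D list of ranks 2K..4K-1
layout_indices = torch.arange(2048).reshape(1024, 2)   # indices produced by some layout
remapped = device_list.flatten()[layout_indices]       # output shape follows the layout
print(remapped.shape)  # torch.Size([1024, 2])
print(remapped[0])     # tensor([2048, 2049])
```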

Comment on lines 351 to 355
# This is important because the indices generated by the layout will be larger than the original mesh tensor
# when the original mesh tensor does not contain all ranks in the world. So we need to scale the layout's stride
# by world_size // mesh_tensor.numel() so that the indices generated by the layout will be within the range of
# the original mesh tensor.
if original_mesh_tensor.numel() != world_size:
Contributor (@fegin), Sep 19, 2025:

Again, is this because of the current DeviceMesh implementation, where the layout covers more than the mesh tensor?

Contributor Author:

Yes, because we still need to support slicing a flattened dim from the root. For that we need to keep the original layout, not the layout derived from the submesh's mesh tensor.
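One possible reading of the stride rescaling described in the quoted code comment, with made-up numbers; this only illustrates the out-of-range problem and a rescaling that fixes it, not the actual implementation.

```python
import torch

world_size = 8
mesh_tensor = torch.tensor([0, 2, 4, 6])       # hypothetical submesh: every other rank
root_stride = 2                                # stride of a layout kept from the root
root_indices = torch.arange(4) * root_stride   # tensor([0, 2, 4, 6]) -- 4 and 6 are
                                               # out of range for a 4-element tensor
factor = world_size // mesh_tensor.numel()     # 2
local_indices = torch.arange(4) * (root_stride // factor)  # tensor([0, 1, 2, 3])
print(mesh_tensor[local_indices])              # tensor([0, 2, 4, 6])
```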

] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
layout: Optional[_MeshLayout] = None,
Contributor:

Since this is a keyword argument, can we move it a bit earlier, at least earlier than _init_backend and _rank, which look like private arguments or arguments that are going to be deprecated, lol.

Contributor Author:

no, I want it to be the last one and private for now.

fduwjj added a commit that referenced this pull request Sep 19, 2025
fduwjj added a commit that referenced this pull request Oct 1, 2025
@fduwjj (Contributor Author) commented Oct 2, 2025

@pytorchbot merge

pytorch-bot bot commented Oct 2, 2025

This PR has pending changes requested. Please address the comments and update the PR before merging.

@fduwjj fduwjj dismissed ezyang’s stale review October 2, 2025 15:34

Edward has already retracted his request for changes.

@fduwjj (Contributor Author) commented Oct 2, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@yangw-dev (Contributor):

@pytorchbot revert -c ghfirst -m "caused internal test failure"

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator):

@fduwjj your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Oct 2, 2025
This reverts commit b098514. Reverted #163213 on behalf of https://github.com/yangw-dev due to an internal test failure.
@pytorchmergebot added the Reverted and ci-no-td labels on Oct 2, 2025
fduwjj added a commit that referenced this pull request Oct 3, 2025
@fduwjj (Contributor Author) commented Oct 3, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 6, 2025 (#164731)

Fixes #163330

I tried to reproduce the bug with my 4-GPU setup (the original issue used 8 GPUs). I created several different test scenarios, trying to trigger the bug by:
- creating two different device meshes
- slicing them in various ways
- checking if get_root_mesh() would get confused

but the bug didn't show up! Everything worked correctly in `2.10`. I found that there was a massive refactoring of the `DeviceMesh` code (PR #163213) that landed on October 2nd. That PR completely rewrote how `DeviceMesh` tracks relationships between parent meshes and submeshes. It seems like this refactoring fixed the bug! But I added a regression test to make sure it doesn't come back. The test (`test_get_root_mesh_multiple_independent_meshes`) does exactly what the bug report described:
  - creates two independent meshes
  - slices them both
  - verifies that each submesh correctly points back to its real parent
  - makes sure submeshes from mesh1 don't incorrectly claim mesh2 as their parent

Pull Request resolved: #164731
Approved by: https://github.com/fduwjj

Labels: ci-no-td, ciflow/trunk, Merged, oncall: distributed, release notes: DeviceMesh, Reverted


7 participants