-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case #163358
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/fduwjj/206/base
Are you sure you want to change the base?
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163358
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 3 Cancelled JobsAs of commit cb8dd0c with merge base d795fb2 ( NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
torch/distributed/device_mesh.py
Outdated
get_world_size(), | ||
) | ||
|
||
for mesh_nd in pg_ranks_by_dim: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?!?! Why do you need to do it for every mesh_nd? Is this because you're triggering comms to initialize PGs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so long story short, we need all ranks to call new_group
which is hidden very deep in the stack to initialize PGs. Otherwise the code will hang.
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…or submesh" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…ubmesh and SPMD use case" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…ubmesh and SPMD use case" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…ubmesh and SPMD use case" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…ubmesh and SPMD use case" Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users. One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation. cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Today FSDP needs to slicing out spmd mesh from root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, users want is a concatenate of some submesh into a big mesh and used as a spmd mesh. This PR is tentatively trying to implement this API for users.
One thing to note is that, all sub-mesh needs to slicing/flatten or unflatten from same root mesh otherwise the indices make no sense when it comes to mesh indexing and device allocation.
cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci