Get rid of dim_groups attribute from DeviceMesh #103105

Closed
wants to merge 8 commits
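For orientation, a minimal sketch of the behavior after this PR (the single-rank "gloo" bootstrap, mesh shape, and env values below are illustrative assumptions, not part of the change): DeviceMesh no longer stores ProcessGroup objects in a _dim_groups attribute. Instead it records a (tag, ranks) pair per mesh dimension in _dim_group_infos, and get_dim_groups(mesh_dim) resolves the actual ProcessGroup on demand via _find_pg_by_ranks_and_tag; calling it with no argument returns the list of groups for all mesh dimensions.

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh

# Single-process bootstrap so the sketch runs standalone (assumed values).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])      # 1-D mesh over the only rank
pg = mesh.get_dim_groups(0)        # ProcessGroup for mesh dim 0 (new optional-arg signature)
all_pgs = mesh.get_dim_groups()    # no argument: list of ProcessGroups, one per mesh dim
print(type(pg).__name__, len(all_pgs))

dist.destroy_process_group()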
8 changes: 2 additions & 6 deletions torch/distributed/_functional_collectives.py
@@ -316,18 +316,14 @@ def cast_listint(x):
elif isinstance(group, dt.DeviceMesh):
assert group.ndim == 1, "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
# TODO: it should run collective in the whole mesh instead of dim 0
- mesh_pg = group.get_dim_groups()[0]
- rankset = dist.get_process_group_ranks(mesh_pg)
+ tag, rankset = group._dim_group_infos[0]
group_size = len(rankset)
- tag = tag or c10d._get_group_tag(mesh_pg)
elif isinstance(group, tuple):
if len(group) == 2 and isinstance(group[0], dt.DeviceMesh) and isinstance(group[1], int):
dmesh = group[0]
dim = group[1]
- dim_group = dmesh.get_dim_groups()[dim]
- rankset = dist.get_process_group_ranks(dim_group)
+ tag, rankset = dmesh._dim_group_infos[dim]
group_size = len(rankset)
- tag = tag or c10d._get_group_tag(dim_group)
else:
raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
else:
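The hunk above only changes how the tag and rank set are obtained when a DeviceMesh (or a (DeviceMesh, int) pair) is passed as the group argument: they are now read directly from _dim_group_infos rather than recomputed from a materialized ProcessGroup. Callers should be unaffected; a hedged usage sketch, continuing the single-rank setup above and assuming the functional all_reduce signature all_reduce(tensor, reduceOp, group):

import torch
import torch.distributed._functional_collectives as funcol

t = torch.ones(4)
# group given as a whole 1-D mesh, or explicitly as a (mesh, dim) pair
r1 = funcol.all_reduce(t, reduceOp="sum", group=mesh)
r2 = funcol.all_reduce(t, reduceOp="sum", group=(mesh, 0))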
61 changes: 41 additions & 20 deletions torch/distributed/_tensor/device_mesh.py
@@ -1,12 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
- from typing import List, Optional, TYPE_CHECKING, Union
+ from typing import List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed._functional_collectives as funcol

from torch.distributed.distributed_c10d import (
+ _find_pg_by_ranks_and_tag,
_get_default_group,
+ _get_group_tag,
all_gather,
all_to_all,
broadcast,
@@ -111,7 +113,7 @@ def __init__(
# process (we need to know if the current global rank is in the mesh or not)
self._get_or_create_default_group()
if _init_process_groups:
- self._dim_groups = self._init_process_groups()
+ self._init_process_groups()

def _get_or_create_default_group(self):
default_initialized = is_initialized()
@@ -150,7 +152,6 @@ def _get_or_create_default_group(self):
return _get_default_group()

def _init_process_groups(self):
- default_pg = _get_default_group()
# check mesh tensor validity
unique_mesh_values = self.mesh.unique(sorted=True)
if unique_mesh_values.numel() != self.mesh.numel():
@@ -169,15 +170,17 @@ def _init_process_groups(self):
f"has mesh {other_mesh}!"
)

- # groups created by dimension, each dimension should have exact
- # one valid process group per rank
- dim_groups: List[ProcessGroup] = []
+ # group tag/ranks associated with each mesh dimension, each mesh dimension should
+ # have one sub-group per rank
+ dim_group_infos: List[Tuple[str, List[int]]] = []

if self.mesh.ndim == 1 and len(unique_mesh_values) == get_world_size():
# if the mesh is the same as world_pg, we just append the default
# pg to the first dim groups, as new_group cannot have the exact
# same ranks as world
- dim_groups.append(default_pg)
+ dim_group_infos.append(
+ (_get_group_tag(_get_default_group()), list(range(get_world_size())))
+ )
else:
# create sub pgs base on the mesh argument specified
for dim in range(self.mesh.ndim):
@@ -193,16 +196,18 @@
# call new_group regardless of the current rank in the
# pg or not, it's required that all ranks participate
# in subgroup construction
- new_subgroup = new_group(ranks=subgroup_ranks)
+ dim_group = new_group(ranks=subgroup_ranks)
# only add to dim_groups if the current rank in the subgroup
if self.get_rank() in subgroup_ranks:
- if len(dim_groups) > dim:
+ if len(dim_group_infos) > dim:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {self.get_rank} "
f"in {subgroup_ranks}!"
)
- dim_groups.append(new_subgroup)
- return dim_groups
+ dim_group_infos.append(
+ (_get_group_tag(dim_group), subgroup_ranks)
+ )
+ self._dim_group_infos = dim_group_infos

def __enter__(self) -> "DeviceMesh":
# set this mesh as the current mesh in mesh env
@@ -227,10 +232,20 @@ def __eq__(self, other: object) -> bool:
return True
return self.mesh.equal(other.mesh)

- def get_dim_groups(self) -> List[ProcessGroup]:
- if not hasattr(self, "_dim_groups"):
+ def get_dim_groups(
+ self, mesh_dim: Optional[int] = None
+ ) -> Union[ProcessGroup, List[ProcessGroup]]:
+ if not hasattr(self, "_dim_group_infos"):
raise RuntimeError("DeviceMesh process groups not initialized!")
- return self._dim_groups
+ if mesh_dim is not None:
+ return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
+ else:
+ dim_groups = []
+ for mesh_dim in range(self.mesh.ndim):
+ dim_groups.append(
+ _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
+ )
+ return dim_groups

def size(self, dim: Optional[int] = None) -> int:
return self.mesh.numel() if dim is None else self.mesh.size(dim)
@@ -279,9 +294,11 @@ def scatter(
# remove the check below once that is done.
if output.is_meta:
return None
- dim_group = self._dim_groups[mesh_dim]
+ dim_group = self.get_dim_groups(mesh_dim)
+ assert isinstance(dim_group, ProcessGroup)
# src need to be global rank
src_for_dim = 0

if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, 0)

@@ -332,7 +349,8 @@ def broadcast(
# remove the check below once that is done.
if tensor.is_meta:
return None
- dim_group = self._dim_groups[mesh_dim]
+ dim_group = self.get_dim_groups(mesh_dim)
+ assert isinstance(dim_group, ProcessGroup)
# src need to be global rank
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
@@ -360,7 +378,8 @@ def all_gather(
Returns:
A :class:`AsyncCollectiveTensor` object
"""
- dim_group = self._dim_groups[mesh_dim]
+ dim_group = self.get_dim_groups(mesh_dim)
+ assert isinstance(dim_group, ProcessGroup)
return funcol.all_gather_tensor(tensor, gather_dim=gather_dim, group=dim_group)

def all_reduce(
@@ -407,13 +426,15 @@ def reduce_scatter(
Returns:
A :class:`torch.Tensor` object
"""

+ dim_group = self.get_dim_groups(mesh_dim)
+ assert isinstance(dim_group, ProcessGroup)
if self.device_type == "cpu":
# cpu::gloo backend does not have reduce_scatter we fallback to do all_reduce
# + local chunk
logger.warning(
"ProcessGroupGloo does not support reduce_scatter, falling back with all reduce!"
)
- dim_group = self._dim_groups[mesh_dim]
group_size = get_world_size(dim_group)
group_rank = get_rank(dim_group)
if scatter_dim != 0:
@@ -424,7 +445,6 @@
chunks = flat_tensor.chunk(group_size, dim=0)
scatter_tensor = chunks[group_rank]
else:
- dim_group = self._dim_groups[mesh_dim]
scatter_tensor = funcol.reduce_scatter_tensor(
input, reduceOp=op.name, scatter_dim=scatter_dim, group=dim_group
)
@@ -439,7 +459,8 @@ def all_to_all(
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
- dim_group = self._dim_groups[mesh_dim]
+ dim_group = self.get_dim_groups(mesh_dim)
+ assert isinstance(dim_group, ProcessGroup)

work = None
# no direct dist.all_to_all support on 'gloo' so we manually do scatters
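Inside DeviceMesh, each collective helper now fetches its group through get_dim_groups(mesh_dim) and asserts that a single ProcessGroup came back. A sketch of addressing per-dimension groups on a 2-D mesh (assumes a 4-process launch such as torchrun --nproc_per_node=4; the mesh layout is an illustration, not taken from this PR):

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh

dist.init_process_group("gloo")  # rank and world size come from the launcher's env
mesh2d = DeviceMesh("cpu", torch.arange(4).reshape(2, 2))

dp_group = mesh2d.get_dim_groups(0)  # ranks in the same column (varies along mesh dim 0)
tp_group = mesh2d.get_dim_groups(1)  # ranks in the same row (varies along mesh dim 1)
print(dist.get_process_group_ranks(dp_group), dist.get_process_group_ranks(tp_group))

dist.destroy_process_group()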
3 changes: 1 addition & 2 deletions torch/distributed/tensor/parallel/_utils.py
@@ -151,6 +151,5 @@ def _create_1d_device_mesh(device_mesh: DeviceMesh, tp_mesh_dim: int = 0) -> Dev
if cur_rank in mesh_1d:
res_sub_mesh = sub_mesh

- sub_pg = device_mesh.get_dim_groups()[tp_mesh_dim]
- res_sub_mesh._dim_groups = [sub_pg]
+ res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[tp_mesh_dim]]
return res_sub_mesh
4 changes: 3 additions & 1 deletion torch/distributed/tensor/parallel/fsdp.py
@@ -184,7 +184,9 @@ def _create_sharded_tensor_md_from_dt(
def _get_dt_pg(dt: DistributedTensor) -> c10d.ProcessGroup:
mesh = dt.device_mesh
assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
- return mesh.get_dim_groups()[0]
+ dim_groups = mesh.get_dim_groups()
+ assert isinstance(dim_groups, list)
+ return dim_groups[0]


def _rewrite_spec_if_needed(