From 24e0f27711a5851293f40e91468702f012ce0c2e Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 21 Oct 2025 11:43:31 -0700 Subject: [PATCH 1/3] [WIP][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map [ghstack-poisoned] --- test/distributed/test_device_mesh.py | 3 + torch/distributed/device_mesh.py | 86 +++++++++++++++++++++++----- 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index a0de1b13c6161..9c17425056df4 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -1000,6 +1000,9 @@ def test_unflatten_mesh_3d(self): ) non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp")) ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp")) + # test pg caching when unflatten into same layout. + self.assertEqual(non_ep_mesh["dp"].get_group(), ep_mesh["dp"].get_group()) + self.assertEqual(non_ep_mesh["tp"].get_group(), ep_mesh["ep_tp"].get_group()) self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh) self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh) mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp")) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 949e04eb5f5f3..afdbdcefa6b22 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -123,6 +123,38 @@ def _get_device_handle(device_type: str = "cuda"): """ return getattr(torch, device_type, None) + class _SharedState: + """ + This class is used to store the shared state of the DeviceMesh. + """ + + _rank_map: torch.Tensor + _root_mesh: "DeviceMesh" + _backend_cache: dict[_MeshLayout, str] + + def __init__(self, rank_map: torch.Tensor, root_mesh: "DeviceMesh") -> None: + self._rank_map = rank_map + self._root_mesh = root_mesh + self._backend_cache: dict[_MeshLayout, str] = {} + self._lock = threading.Lock() + + def get_rank_map(self) -> torch.Tensor: + with self._lock: + return self._rank_map + + def get_root_mesh(self) -> "DeviceMesh": + with self._lock: + return self._root_mesh + + def update_backend_cache(self, layout: _MeshLayout, backend: str) -> None: + with self._lock: + if layout not in self._backend_cache: + self._backend_cache[layout] = backend + + def get_backend_from_cache(self, layout: _MeshLayout) -> Optional[str]: + with self._lock: + return self._backend_cache.get(layout, None) + class DeviceMesh: """ DeviceMesh represents a mesh of devices, where layout of devices could be @@ -172,12 +204,13 @@ class DeviceMesh: """ _device_type: str - _rank_map: torch.Tensor + _rank_map: torch.Tensor # TODO: remove _mesh_dim_names: Optional[tuple[str, ...]] _layout: _MeshLayout - _root_mesh: Optional["DeviceMesh"] = None + _root_mesh: Optional["DeviceMesh"] = None # TODO: remove # Record flatten mesh name to its flattened mesh in root mesh. _flatten_mapping: dict[str, "DeviceMesh"] + _shared_state: _SharedState def __init__( self, @@ -191,6 +224,7 @@ def __init__( _layout: Optional[_MeshLayout] = None, _rank_map: Optional[torch.Tensor] = None, _root_mesh: Optional["DeviceMesh"] = None, + _shared_state: Optional[_SharedState] = None, ) -> None: if mesh is not None: if _layout is not None or _rank_map is not None: @@ -236,6 +270,11 @@ def __init__( f"but got {len(backend_override)} and {len(self._layout)}." 
) + if _shared_state is None: + self._shared_state = _SharedState(self._rank_map, self) + else: + self._shared_state = _shared_state + # Skip process group initialization if xla device or init backend is False # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. self._thread_id = None @@ -248,6 +287,7 @@ def __init__( self._dim_group_names = self._init_process_groups( self._layout, self._rank_map, + self._shared_state, self._mesh_dim_names, backend_override, ) @@ -350,6 +390,7 @@ def _setup_world_group_and_device(self): def _init_process_groups( layout: _MeshLayout, rank_map: torch.Tensor, + shared_state: _SharedState, mesh_dim_names: Optional[tuple[str, ...]], backend_override: tuple[BackendConfig, ...], ) -> list[str]: @@ -364,26 +405,37 @@ def _init_process_groups( and layout.numel() == get_world_size() and backend_override[0] == (None, None) ): - # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`. - # Otherwise, create new pg. - ranks = list(range(get_world_size())) - dim_group = ( - new_group( - backend="cpu:gloo,cuda:nccl", - ranks=ranks, - group_desc="mesh_default", + backend_cache = shared_state.get_backend_from_cache(layout) + if backend_cache is not None: + dim_group_names.append(backend_cache) + else: + # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`. + # Otherwise, create new pg. + ranks = list(range(get_world_size())) + dim_group = ( + new_group( + backend="cpu:gloo,cuda:nccl", + ranks=ranks, + group_desc="mesh_default", + ) + if torch.cuda.is_available() + and get_backend(default_group) == "gloo" + else default_group ) - if torch.cuda.is_available() - and get_backend(default_group) == "gloo" - else default_group - ) - dim_group_names.append(dim_group.group_name) + shared_state.update_backend_cache(layout, dim_group.group_name) + dim_group_names.append(dim_group.group_name) else: # create sub pgs base on the mesh argument specified for dim in range(len(layout)): # swap the current dim to the last dim # then reshape to flatten out other dims pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map) + # PG Cache check here. + backend_cache = shared_state.get_backend_from_cache(layout[dim]) + if backend_cache is not None: + dim_group_names.append(backend_cache) + continue + backend, pg_options = backend_override[dim] # We need to explicitly pass in timeout when specified in option, otherwise # the default timeout will be used to override the timeout set in option. @@ -460,6 +512,9 @@ def _init_process_groups( f"Each device mesh dimension should get only one process group, but got {get_rank()} " f"in {subgroup_ranks}!" 
) + backend_cache = shared_state.update_backend_cache( + layout[dim], dim_group.group_name + ) dim_group_names.append(dim_group.group_name) # type: ignore[union-attr] return dim_group_names @@ -1105,6 +1160,7 @@ def _create_unflatten_mesh( dim_group_names[dim : dim + 1] = self._init_process_groups( partial_layout, root_mesh._rank_map, + root_mesh._shared_state, mesh_dim_names, backend_override, ) From 602f1f42c2260cc914882d74da6eecee24baeccc Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 21 Oct 2025 15:43:22 -0700 Subject: [PATCH 2/3] Update on "[WIP][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map" cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned] --- torch/distributed/device_mesh.py | 59 ++++++++++++++++---------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index afdbdcefa6b22..89d67b728d6c3 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -146,6 +146,10 @@ def get_root_mesh(self) -> "DeviceMesh": with self._lock: return self._root_mesh + def set_root_mesh(self, root_mesh: "DeviceMesh") -> None: + with self._lock: + self._root_mesh = root_mesh + def update_backend_cache(self, layout: _MeshLayout, backend: str) -> None: with self._lock: if layout not in self._backend_cache: @@ -222,14 +226,12 @@ def __init__( _init_backend: bool = True, _rank: Optional[int] = None, _layout: Optional[_MeshLayout] = None, - _rank_map: Optional[torch.Tensor] = None, - _root_mesh: Optional["DeviceMesh"] = None, _shared_state: Optional[_SharedState] = None, ) -> None: if mesh is not None: - if _layout is not None or _rank_map is not None: + if _layout is not None or _shared_state is not None: raise TypeError( - "Cannot provide _layout and/or _rank_map if passing explicit mesh" + "Cannot provide _layout and/or _shared_state if passing explicit mesh" ) if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") @@ -239,28 +241,29 @@ def __init__( else torch.tensor(mesh, device="cpu", dtype=torch.int) ) _layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride()) - _rank_map = mesh_tensor.flatten() + rank_map = mesh_tensor.flatten() else: - if _layout is None or _rank_map is None: + if _layout is None or _shared_state is None: raise TypeError( "The mesh argument is required except for PRIVATE USAGE ONLY!" ) + if _shared_state.get_root_mesh() is None: + _shared_state.set_root_mesh(self) + rank_map = _shared_state.get_rank_map() assert _layout.check_non_overlap(), ( "Please use a non-overlapping layout when creating a DeviceMesh." 
) - assert _rank_map.ndim == 1, "The rank map must be 1-dimensional" - assert _rank_map.is_contiguous(), "The rank map must be contiguous" - assert _rank_map.numel() >= _layout.cosize(), ( - f"The rank map contains {_rank_map.numel()} element, " + assert rank_map.ndim == 1, "The rank map must be 1-dimensional" + assert rank_map.is_contiguous(), "The rank map must be contiguous" + assert rank_map.numel() >= _layout.cosize(), ( + f"The rank map contains {rank_map.numel()} element, " f"which isn't large enough for layout {_layout}" ) self._device_type = device_type self._layout = _layout - self._rank_map = _rank_map self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None - self._root_mesh = _root_mesh if backend_override is None: backend_override = ((None, None),) * len(self._layout) @@ -271,7 +274,7 @@ def __init__( ) if _shared_state is None: - self._shared_state = _SharedState(self._rank_map, self) + self._shared_state = _SharedState(rank_map, self) else: self._shared_state = _shared_state @@ -286,7 +289,6 @@ def __init__( self._setup_world_group_and_device() self._dim_group_names = self._init_process_groups( self._layout, - self._rank_map, self._shared_state, self._mesh_dim_names, backend_override, @@ -307,7 +309,7 @@ def __init__( ) # private field to pre-generate DeviceMesh's hash - self._flatten_rank_map = tuple(self._rank_map.tolist()) + self._flatten_rank_map = tuple(rank_map.tolist()) # Initialize instance-specific flatten mapping self._flatten_mapping = {} @@ -319,7 +321,7 @@ def device_type(self) -> str: @property def mesh(self) -> torch.Tensor: """Returns the tensor representing the layout of devices.""" - full_mesh = self._layout.remap_to_tensor(self._rank_map) + full_mesh = self._layout.remap_to_tensor(self._shared_state.get_rank_map()) if full_mesh.size(0) == 1: return full_mesh[0] my_coords = (full_mesh == get_rank()).nonzero() @@ -389,7 +391,6 @@ def _setup_world_group_and_device(self): @staticmethod def _init_process_groups( layout: _MeshLayout, - rank_map: torch.Tensor, shared_state: _SharedState, mesh_dim_names: Optional[tuple[str, ...]], backend_override: tuple[BackendConfig, ...], @@ -429,7 +430,9 @@ def _init_process_groups( for dim in range(len(layout)): # swap the current dim to the last dim # then reshape to flatten out other dims - pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map) + pg_ranks_by_dim = ( + layout[dim].nest().remap_to_tensor(shared_state.get_rank_map()) + ) # PG Cache check here. 
backend_cache = shared_state.get_backend_from_cache(layout[dim]) if backend_cache is not None: @@ -519,7 +522,7 @@ def _init_process_groups( return dim_group_names def _get_root_mesh(self) -> "DeviceMesh": - return self._root_mesh if self._root_mesh else self + return self._shared_state.get_root_mesh() def __enter__(self) -> "DeviceMesh": # set this mesh as the current mesh in mesh env @@ -720,10 +723,9 @@ def _create_sub_mesh( res_submesh = DeviceMesh( self._device_type, _layout=layout, - _rank_map=root_mesh._rank_map, mesh_dim_names=submesh_dim_names, - _root_mesh=root_mesh, _init_backend=False, + _shared_state=root_mesh._shared_state, ) res_submesh._dim_group_names = slice_dim_group_name return res_submesh @@ -770,10 +772,9 @@ def _create_flatten_mesh( res_flattened_mesh = DeviceMesh( root_mesh._device_type, _layout=flattened_mesh_layout, - _rank_map=root_mesh._rank_map, mesh_dim_names=(mesh_dim_name,), - _root_mesh=root_mesh, backend_override=(backend_override,), + _shared_state=root_mesh._shared_state, ) root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh @@ -902,7 +903,7 @@ def _get_all_submeshes(self, mesh_dim_name: str) -> list["DeviceMesh"]: """ mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name) layout = self._layout[mesh_dim] - pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map) + pg_ranks_by_dim = layout.remap_to_tensor(self._shared_state.get_rank_map()) cur_rank = self.get_rank() res_submeshes = [] for mesh_1d in pg_ranks_by_dim: @@ -1145,10 +1146,9 @@ def _create_unflatten_mesh( res_mesh = DeviceMesh( self.device_type, _layout=unflattened_layout, - _rank_map=root_mesh._rank_map, mesh_dim_names=tuple(unflattened_mesh_dim_names), - _root_mesh=root_mesh, _init_backend=False, + _shared_state=root_mesh._shared_state, ) # If original mesh has initiated its backend, we need to initialize the backend @@ -1159,7 +1159,6 @@ def _create_unflatten_mesh( dim_group_names = self._dim_group_names.copy() dim_group_names[dim : dim + 1] = self._init_process_groups( partial_layout, - root_mesh._rank_map, root_mesh._shared_state, mesh_dim_names, backend_override, @@ -1259,10 +1258,9 @@ def _concatenate(device_mesh_list: list["DeviceMesh"]) -> "DeviceMesh": res_mesh = DeviceMesh( device_mesh_list[0].device_type, _layout=concat_mesh_layout, - _rank_map=device_mesh_list[0]._rank_map, mesh_dim_names=tuple(concat_dim_names), - _root_mesh=device_mesh_list[0]._get_root_mesh(), _init_backend=False, + _shared_state=device_mesh_list[0]._shared_state, ) res_mesh._dim_group_names = concat_dim_group_name return res_mesh @@ -1391,12 +1389,13 @@ def init_device_mesh( # external device type has been set to be (e.g. meta) with torch.device("cpu"): rank_map = torch.arange(layout.numel(), dtype=torch.int) + shared_state = _SharedState(rank_map, None) device_mesh = DeviceMesh( device_type=device_type, _layout=layout, - _rank_map=rank_map, mesh_dim_names=mesh_dim_names, backend_override=backend_override_tuple, + _shared_state=shared_state, ) return device_mesh From e186bca40f30017346c4fad32bde522fc60e2927 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 21 Oct 2025 16:07:13 -0700 Subject: [PATCH 3/3] Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map" We want to create a shared_state to store root_mesh, rank_map and pg caches. We can add more into it down the road, so that it becomes a singleton for bookkeeping and also align with our original proposal to move toward the idea of mesh universe. 
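For reviewers, here is a minimal, self-contained sketch of the pattern this stack introduces. It is illustrative only: `SharedMeshState` and `group_name_for` are stand-in names rather than the actual `_SharedState` API, and the real cache key is a `_MeshLayout`, not a bare hashable object.

```python
import threading
from typing import Callable, Optional

import torch


class SharedMeshState:
    """Sketch of the bookkeeping object: one instance is shared by a root
    DeviceMesh and every mesh sliced, flattened, or unflattened from it."""

    def __init__(self, rank_map: torch.Tensor, root_mesh=None) -> None:
        self._rank_map = rank_map                    # flat 1-D tensor of global ranks
        self._root_mesh = root_mesh                  # filled in lazily by the first mesh
        self._backend_cache: dict[object, str] = {}  # hashable layout -> PG name
        self._lock = threading.Lock()

    def get_rank_map(self) -> torch.Tensor:
        with self._lock:
            return self._rank_map

    def get_backend_from_cache(self, layout) -> Optional[str]:
        with self._lock:
            return self._backend_cache.get(layout)

    def update_backend_cache(self, layout, group_name: str) -> None:
        with self._lock:
            # keep the first group name recorded for this layout
            self._backend_cache.setdefault(layout, group_name)


def group_name_for(state: SharedMeshState, layout, create_group: Callable[[], str]) -> str:
    """Check the cache first; only create (and record) a new group on a miss."""
    cached = state.get_backend_from_cache(layout)
    if cached is not None:
        return cached
    name = create_group()  # in the real code this is the new_group(...) path
    state.update_backend_cache(layout, name)
    return name
```

As in the patch, only the cache lookups and updates are serialized by the lock; the collective group-creation call itself happens between them, outside the lock.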
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned] --- torch/distributed/device_mesh.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 89d67b728d6c3..a4923ddf4f812 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -129,10 +129,12 @@ class _SharedState: """ _rank_map: torch.Tensor - _root_mesh: "DeviceMesh" + _root_mesh: Optional["DeviceMesh"] _backend_cache: dict[_MeshLayout, str] - def __init__(self, rank_map: torch.Tensor, root_mesh: "DeviceMesh") -> None: + def __init__( + self, rank_map: torch.Tensor, root_mesh: Optional["DeviceMesh"] = None + ) -> None: self._rank_map = rank_map self._root_mesh = root_mesh self._backend_cache: dict[_MeshLayout, str] = {} @@ -142,7 +144,7 @@ def get_rank_map(self) -> torch.Tensor: with self._lock: return self._rank_map - def get_root_mesh(self) -> "DeviceMesh": + def get_root_mesh(self) -> Optional["DeviceMesh"]: with self._lock: return self._root_mesh @@ -515,14 +517,15 @@ def _init_process_groups( f"Each device mesh dimension should get only one process group, but got {get_rank()} " f"in {subgroup_ranks}!" ) - backend_cache = shared_state.update_backend_cache( - layout[dim], dim_group.group_name + shared_state.update_backend_cache( + layout[dim], + dim_group.group_name, # type: ignore[union-attr] ) dim_group_names.append(dim_group.group_name) # type: ignore[union-attr] return dim_group_names def _get_root_mesh(self) -> "DeviceMesh": - return self._shared_state.get_root_mesh() + return not_none(self._shared_state.get_root_mesh()) def __enter__(self) -> "DeviceMesh": # set this mesh as the current mesh in mesh env @@ -1389,7 +1392,7 @@ def init_device_mesh( # external device type has been set to be (e.g. meta) with torch.device("cpu"): rank_map = torch.arange(layout.numel(), dtype=torch.int) - shared_state = _SharedState(rank_map, None) + shared_state = _SharedState(rank_map) device_mesh = DeviceMesh( device_type=device_type, _layout=layout,
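At the API surface, the caching shows up as below. This is a sketch mirroring the new assertions in `test_unflatten_mesh_3d`, not additional test code: it assumes an 8-rank job with the process group already initialized (or the usual env vars set), and reuses the dim names from the test; `_unflatten` is the private API exercised there.

```python
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is initialized with 8 ranks on CUDA devices.
global_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("world",))

non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))

# The two unflattens produce identical layouts for the first and last dims, so
# the shared-state cache hands back the process groups created the first time
# instead of calling new_group() again.
assert non_ep_mesh["dp"].get_group() == ep_mesh["dp"].get_group()
assert non_ep_mesh["tp"].get_group() == ep_mesh["ep_tp"].get_group()
```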