2121import threading
2222from abc import ABC
2323from datetime import timedelta
24- from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
24+ from typing import Any , TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
2525
2626import torch
2727import torch .distributed as dist
@@ -852,6 +852,8 @@ def extend_device_mesh(
852852
853853
854854class ManagedDeviceMesh (DeviceMesh ):
855+ replicate_pg_singleton : Optional ["ManagedProcessGroup" ]
856+
855857 def __init__ (
856858 self ,
857859 mesh : Optional [DeviceMesh ],
@@ -880,6 +882,15 @@ def __init__(
880882 self ._flatten_mesh_list : Tuple [DeviceMesh , ...] = tuple ()
881883 self ._thread_id : Optional [int ] = None
882884
885+ def __getstate__ (self ) -> Dict [str , Any ]:
886+ state = self .__dict__ .copy ()
887+ state ["replicate_pg" ] = None
888+ return state
889+
890+ def __setstate__ (self , state : Dict [str , Any ]) -> None :
891+ self .__dict__ .update (state )
892+ self .replicate_pg = self .replicate_pg_singleton
893+
883894 def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
884895 if isinstance (mesh_dim_names , str ):
885896 if mesh_dim_names == self .replicate_dim_name :
@@ -897,13 +908,14 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
897908 return self .mesh [mesh_dim_names ]
898909 else :
899910 assert isinstance (mesh_dim_names , tuple )
900- if self .replicate_dim_name in mesh_dim_names :
911+ if self .replicate_dim_name not in mesh_dim_names :
901912 assert self .mesh is not None
902913 return self .mesh [mesh_dim_names ]
903914 else :
904915 assert self .mesh is not None
916+ mesh_dim_names_wo_replicate = tuple (n for n in mesh_dim_names if n != self .replicate_dim_name )
905917 return ManagedDeviceMesh (
906- self .mesh [mesh_dim_names ],
918+ self .mesh [mesh_dim_names_wo_replicate ],
907919 mesh_dim_names ,
908920 self .replicate_pg ,
909921 mesh_dim_names .index (self .replicate_dim_name ),
@@ -938,14 +950,16 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
938950 return flatten_mesh
939951
940952 def size (self , mesh_dim : Optional [int ] = None ) -> int :
953+ replicate_pg_size = self .replicate_pg .size ()
954+ replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
941955 if mesh_dim is None :
942956 if self .mesh is None :
943- return self . replicate_pg . size ()
957+ return replicate_pg_size
944958 else :
945959 assert self .mesh is not None
946- return self .mesh .size () * self . replicate_pg . size ()
960+ return self .mesh .size () * replicate_pg_size
947961 elif mesh_dim == self .replicate_dim :
948- return self . replicate_pg . size ()
962+ return replicate_pg_size
949963 else :
950964 assert self .mesh is not None
951965 return self .mesh .size (self ._real_mesh_dim (mesh_dim ))
@@ -995,7 +1009,11 @@ def get_coordinate(self) -> Optional[List[int]]:
9951009 dimensions of the mesh. If this rank is not part of the mesh, return None.
9961010 """
9971011 assert self .mesh is not None
998- return self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
1012+ ret = self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
1013+ if ret :
1014+ ret = ret .copy ()
1015+ ret .insert (get_rank (self .replicate_pg ), self .replicate_dim )
1016+ return ret
9991017
10001018 def get_all_groups (self ) -> List [BaseProcessGroup ]:
10011019 raise NotImplementedError
@@ -1070,6 +1088,8 @@ def ft_init_device_mesh(
10701088 # the same backend has been registered.
10711089 replicate_pg .register (mesh_dim_names [replicate_dim ])
10721090
1091+ ManagedDeviceMesh .replicate_pg_singleton = replicate_pg
1092+
10731093 return ManagedDeviceMesh (
10741094 mesh = mesh ,
10751095 mesh_dim_names = mesh_dim_names ,
0 commit comments