[DeviceMesh] Fix hash and eq not match (#123572)
Fixes #121799

We fix the DeviceMesh hash so that two DeviceMesh objects are considered equal if they have the same mesh tensor and the same parent mesh.
Examples can be found in #121799.

This is also needed to unblock #123394.
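
For illustration (not part of this commit), a minimal sketch of the equality and hash semantics that the new tests below exercise; it assumes a CPU/gloo setup launched with `torchrun --nproc_per_node=8`:

```python
# Minimal sketch of the intended semantics after this change (assumption:
# launched as `torchrun --nproc_per_node=8 sketch.py`, gloo backend, CPU).
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

dist.init_process_group("gloo")

# Two standalone meshes built from the same mesh tensor (and with no parent
# mesh) now hash and compare equal.
mesh_a = DeviceMesh("cpu", torch.arange(8).reshape(4, 2))
mesh_b = DeviceMesh("cpu", torch.arange(8).reshape(4, 2))
assert mesh_a == mesh_b and hash(mesh_a) == hash(mesh_b)

# A sliced mesh keeps a reference to its parent (and has mesh_dim_names), so
# it is not equal to a standalone mesh over the same ranks, even though the
# underlying mesh tensors match.
mesh_2d = init_device_mesh("cpu", (2, 4), mesh_dim_names=("DP", "TP"))
tp_mesh = mesh_2d["TP"]
standalone_1 = DeviceMesh("cpu", torch.arange(0, 4))
standalone_2 = DeviceMesh("cpu", torch.arange(4, 8))
standalone = standalone_1 if dist.get_rank() < 4 else standalone_2
assert tp_mesh != standalone and hash(tp_mesh) != hash(standalone)

dist.destroy_process_group()
```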

Pull Request resolved: #123572
Approved by: https://github.com/xunnanxu, https://github.com/wanchaol, https://github.com/yoyoyocmu
wz337 authored and pytorchmergebot committed May 16, 2024
1 parent 1876f0f commit 059b68f
Showing 3 changed files with 82 additions and 17 deletions.
12 changes: 9 additions & 3 deletions test/distributed/_composable/fsdp/test_fully_shard_init.py
@@ -685,8 +685,9 @@ def test_process_group_init(self):

# Check the `from_group()` API for correctness
dp_mesh = DeviceMesh.from_group(dp_pg, "cuda")
# We only compare the mesh tensors instead of the DeviceMesh objects
# since mesh_dim_names attributes and parent mesh are different.
self.assertEqual(dp_mesh.mesh, ref_dp_mesh.mesh)
self.assertEqual(dp_mesh, ref_dp_mesh)
# self.assertFalse(hasattr(dp_mesh, "_coordinate_on_dim"))
self.assertEqual(dp_mesh._coordinate_on_dim, ref_dp_mesh._coordinate_on_dim)
self.assertEqual(dp_mesh._dim_group_infos, ref_dp_mesh._dim_group_infos)
@@ -721,8 +722,13 @@ def test_process_group_init(self):
loss.backward()
self.assertEqual(loss, ref_loss)
for param, ref_param in zip(model.parameters(), ref_model.parameters()):
self.assertEqual(param, ref_param)
self.assertEqual(param.grad, ref_param.grad)
# we cannot directly compare param and ref_param because their parent mesh is different.
self.assertEqual(param.to_local(), ref_param.to_local())
self.assertEqual(param.device_mesh.mesh, ref_param.device_mesh.mesh)
self.assertEqual(param.grad.to_local(), ref_param.grad.to_local())
self.assertEqual(
param.grad.device_mesh.mesh, ref_param.grad.device_mesh.mesh
)


class TestFullyShardHSDPBroadcast(FSDPTestMultiThread):
52 changes: 45 additions & 7 deletions test/distributed/test_device_mesh.py
@@ -225,7 +225,7 @@ def test_device_mesh_hash(self):
mesh_tensor_2d = torch.arange(8).reshape(4, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor_2d)
mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d)
self.assertNotEqual(hash(mesh), hash(mesh2))
self.assertEqual(hash(mesh), hash(mesh2))
mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2)
mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d)
self.assertNotEqual(hash(mesh), hash(mesh3))
@@ -269,6 +269,44 @@ def test_get_local_rank_3d(self):
expected_dp_rank = self.rank // 4
self.assertEqual(dp_rank, expected_dp_rank)

@with_comms
def test_device_mesh_parent_child_hash(self):
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("DP", "TP")
)

mesh_group_1 = torch.arange(0, self.world_size // 2)
mesh_group_2 = torch.arange(self.world_size // 2, self.world_size)
ep_mesh_1 = DeviceMesh(self.device_type, mesh_group_1)
ep_mesh_2 = DeviceMesh(self.device_type, mesh_group_2)
ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
# ep_mesh is considered different from mesh_2d["TP"]
# since mesh_2d["TP"] has a parent mesh while ep_mesh does not.
self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names)
self.assertEqual(mesh_2d["TP"]._thread_id, ep_mesh._thread_id)
self.assertNotEqual(mesh_2d["TP"]._parent_mesh, ep_mesh._parent_mesh)
self.assertNotEqual(hash(mesh_2d["TP"]), hash(ep_mesh))
self.assertNotEqual(mesh_2d["TP"], ep_mesh)

another_mesh_1 = DeviceMesh(self.device_type, mesh_group_1)
another_mesh_2 = DeviceMesh(self.device_type, mesh_group_2)
another_mesh = (
another_mesh_1 if self.rank < self.world_size // 2 else another_mesh_2
)
# another_mesh is considered the same as ep_mesh
# since they have the same mesh and no parent mesh.
self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names)
self.assertEqual(ep_mesh._thread_id, another_mesh._thread_id)
self.assertEqual(ep_mesh._parent_mesh, another_mesh._parent_mesh)
self.assertEqual(hash(ep_mesh), hash(another_mesh))
self.assertEqual(ep_mesh, another_mesh)


class InitDeviceMeshTest(DTensorTestBase):
@property
@@ -278,20 +316,20 @@ def world_size(self):
@with_comms
def test_init_device_mesh(self):
mesh_shape = (2, 4)
ref_mesh = DeviceMesh(self.device_type, torch.arange(8).view(mesh_shape))
mesh_dim_names = ("DP", "TP")
ref_mesh = DeviceMesh(
self.device_type,
torch.arange(8).view(mesh_shape),
mesh_dim_names=mesh_dim_names,
)

# test init_device_mesh with mesh_dim_names
mesh_dim_names = ("DP", "TP")
mesh_2d = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
self.assertEqual(mesh_2d, ref_mesh)
self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names)

# test init_device_mesh without mesh_dim_names
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
self.assertEqual(mesh_2d, ref_mesh)

@with_comms
def test_raises_duplicate_mesh_dim_names(self):
with self.assertRaisesRegex(
35 changes: 28 additions & 7 deletions torch/distributed/device_mesh.py
@@ -89,7 +89,9 @@ def create_child_mesh(
res_sub_mesh = sub_mesh

res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined]
res_sub_mesh._parent_mesh = device_mesh
# Assign the current DeviceMesh as the parent of the child DeviceMesh.
# We need to update the mappings after the child mesh hash update.
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
return res_sub_mesh

@@ -209,11 +211,12 @@ def __init__(
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, dtype=torch.int)
)
self.mesh_dim_names = mesh_dim_names
self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None

# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))
self._parent_mesh: Optional["DeviceMesh"] = None
self._thread_id = threading.get_ident()

# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
@@ -334,17 +337,35 @@ def __repr__(self) -> str:
return device_mesh_repr

def __hash__(self):
# lazily compute hash
self._hash = getattr(self, "_hash", None)
if not self._hash:
self._hash = hash(
(
self._flatten_mesh_list,
self.mesh.shape,
self.device_type,
self.mesh_dim_names,
self._parent_mesh,
self._thread_id,
)
)
return self._hash

def __eq__(self, other: object) -> bool:
if not isinstance(other, DeviceMesh):
return False
if id(self.mesh) == id(other.mesh):
if id(self) == id(other):
return True
return (
self.mesh.shape == other.mesh.shape
and self._flatten_mesh_list == other._flatten_mesh_list
)
else:
return (
self._flatten_mesh_list == other._flatten_mesh_list
and self.mesh.shape == other.mesh.shape
and self.device_type == other.device_type
and self.mesh_dim_names == other.mesh_dim_names
and self._parent_mesh == other._parent_mesh
and self._thread_id == other._thread_id
)

def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
"""
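Because `__hash__` and `__eq__` now consider the same set of attributes, `DeviceMesh` objects behave consistently as dictionary keys, which is what the `child_to_parent_mapping` update above relies on. A minimal sketch under the same torchrun/gloo assumptions as the earlier example; the `metadata` dict is hypothetical:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

dist.init_process_group("gloo")  # e.g. torchrun --nproc_per_node=8

metadata = {}  # hypothetical mapping keyed by DeviceMesh
mesh = DeviceMesh("cpu", torch.arange(8).reshape(4, 2))
metadata[mesh] = "registered"

# An equal mesh (same mesh tensor, device type, dim names, and no parent)
# hashes into the same bucket, so dict lookups agree with `==`.
same_mesh = DeviceMesh("cpu", torch.arange(8).reshape(4, 2))
assert same_mesh == mesh
assert metadata[same_mesh] == "registered"

dist.destroy_process_group()
```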
