Support unordered sharding spec for partial replication (#5316)
* Support unordered sharding spec for partial replication

* add 4d test

* handle 2d tensor with 3d mesh case

* refactoring
JackCaoG committed Jul 21, 2023
1 parent 080fdcf commit aac03da
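In SPMD sharding, partition_spec maps each tensor dimension to a device-mesh axis (or None to leave that dimension unsharded). An "unordered" spec references mesh axes out of their natural order, and a partial spec leaves some mesh axes unreferenced, so the tensor is replicated across those axes. The snippet below is a minimal sketch of the usage this commit enables, modeled on the new tests; the mesh construction and tensor sizes are illustrative, not part of the change.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_device_count()
z_dim = 2 if num_devices >= 4 else 1
# Illustrative 3D device mesh, e.g. (2, 1, 4) on 8 devices.
mesh = xs.Mesh(list(range(num_devices)), (z_dim, 1, num_devices // z_dim))

t = torch.randn(8, 16, 32).to(xm.xla_device())
# Unordered, partial spec: tensor dim 0 -> mesh axis 1, tensor dim 1 -> mesh
# axis 0, tensor dim 2 unsharded; the unreferenced mesh axis is used for
# replication.
xt = xs.mark_sharding(t, mesh, (1, 0, None))
# On 8 devices the spec should contain 'devices=[1,2,1,4]' and
# 'last_tile_dim_replicate'.
print(torch_xla._XLAC._get_xla_sharding_spec(t))

On eight devices this means two shards along tensor dim 1, each replicated across four devices.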
Showing 2 changed files with 101 additions and 14 deletions.
89 changes: 89 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -301,6 +301,95 @@ def test_mark_sharding_partial(self):
actual = (xt1 @ t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

def test_mark_sharding_not_ordered_partial_3d(self):
device = xm.xla_device()
t1 = torch.randn(8, 16, 32).to(device)
t2 = torch.randn(8, 16, 32).to(device)
# Somehow the eager cpu result is different from the xla result.
expected = t1 + t2
# To re-materialize t1 and t2.
xm.mark_step()
xm.wait_device_ops()
expected = expected.cpu()

# Shard along two axes if four or more devices are available
z_dim = 2 if self.n_devices >= 4 else 1
mesh = self._get_mesh((z_dim, 1, self.n_devices // z_dim))

# Expect local shard size to be [8, 16 / z_dim, 32]
xt1 = xs.mark_sharding(t1, mesh, (1, 0, None))

for local_shard in xt1.local_shards:
self.assertEqual(local_shard.data.size()[0], 8)
self.assertEqual(local_shard.data.size()[1], 16 / z_dim)
self.assertEqual(local_shard.data.size()[2], 32)

# partial replication requires >1 devices; otherwise, it's replicated.
if self.n_devices > 1:
# xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
self.assertTrue('last_tile_dim_replicate' in
torch_xla._XLAC._get_xla_sharding_spec(t1))
self.assertTrue('[%d,%d,1,%d]' %
(1, z_dim, self.n_devices //
z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1))
actual = (xt1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

def test_mark_sharding_not_ordered_partial_4d(self):
device = xm.xla_device()
t1 = torch.randn(8, 16, 32, 64).to(device)
t2 = torch.randn(8, 16, 32, 64).to(device)
# Somehow the eager cpu result is different from the xla result.
expected = t1 + t2
# To re-materialize t1 and t2.
xm.mark_step()
xm.wait_device_ops()
expected = expected.cpu()

# Shard along two axes if four or more devices are available
z_dim = 2 if self.n_devices >= 4 else 1
mesh = self._get_mesh((z_dim, 1, 1, self.n_devices // z_dim))

# Expect local shard size to be [8, 16, 32 / z_dim, 64]
xt1 = xs.mark_sharding(t1, mesh, (2, None, 0, None))

for local_shard in xt1.local_shards:
self.assertEqual(local_shard.data.size()[0], 8)
self.assertEqual(local_shard.data.size()[1], 16)
self.assertEqual(local_shard.data.size()[2], 32 / z_dim)
self.assertEqual(local_shard.data.size()[3], 64)

# partial replication requires >1 devices; otherwise, it's replicated.
if self.n_devices > 1:
# xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
self.assertTrue('last_tile_dim_replicate' in
torch_xla._XLAC._get_xla_sharding_spec(t1))
self.assertTrue('[1,1,%d,1,%d]' %
(z_dim,
(self.n_devices //
z_dim)) in torch_xla._XLAC._get_xla_sharding_spec(t1))
actual = (xt1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self):
ct1 = torch.randn(16, 16, device='cpu')
ct2 = torch.randn(16, 16, device='cpu')
expected = ct1 + ct2

t1 = ct1.to(xm.xla_device())
t2 = ct2.to(xm.xla_device())
mesh = self._get_mesh((1, self.n_devices, 1))
# sharding spec here is not ordered.
xt1 = xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
if self.n_devices > 1:
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt1.global_tensor])
sharding_annotation = 'sharding={devices=[1,1,%d]%s}' % (
self.n_devices, ','.join(
[str(d) for d in mesh.get_logical_mesh().flatten()]))
self.assertIn(sharding_annotation, hlo)
actual = (xt1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

def test_partial_replication_addmm(self):
device = xm.xla_device()
z_dim = 2 if self.n_devices >= 4 else 1
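To make the expected shard shapes and sharding annotations in these tests concrete, the sketch below uses a hypothetical helper (not part of torch_xla) to compute the local shard shape and replication factor implied by a mesh shape and an unordered, partial partition_spec.

def expected_shard_shape(tensor_shape, mesh_shape, partition_spec):
  # Each tensor dim is divided by the size of the mesh axis it maps to;
  # dims mapped to None stay whole.
  shard = [
      size if axis is None else size // mesh_shape[axis]
      for size, axis in zip(tensor_shape, partition_spec)
  ]
  # Mesh axes not referenced in the spec form the replication group.
  used = {axis for axis in partition_spec if axis is not None}
  replicas = 1
  for axis, size in enumerate(mesh_shape):
    if axis not in used:
      replicas *= size
  return shard, replicas

# Mesh (2, 1, 4) on 8 devices with spec (1, 0, None), as in the 3D test above:
print(expected_shard_shape((8, 16, 32), (2, 1, 4), (1, 0, None)))
# ([8, 8, 32], 4) -> local shard [8, 16 / z_dim, 32] and the
# 'devices=[1,2,1,4] ... last_tile_dim_replicate' annotation asserted above.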
26 changes: 12 additions & 14 deletions torch_xla/experimental/xla_sharding.py
@@ -323,10 +323,6 @@ def _get_tile_assignment(mesh: Mesh,
partition_spec: Tuple[Union[int, None]]) -> List[int]:
# Use Torch.tensor here to make use of the torch.transpose_
mesh_list_tensor = torch.tensor(mesh.get_logical_mesh().tolist())
-  # This is partial sharding case, tile_assigniment will be ignore in favor of
-  # group_assignment and replication_groups.
-  if (mesh_list_tensor.dim() != len(partition_spec)):
-    return mesh_list_tensor.tolist()
partition_spec_list = list(partition_spec)
for i in range(len(partition_spec_list)):
if partition_spec_list[i] == None:
@@ -339,26 +335,28 @@ def _get_group_assignment(
return mesh_list_tensor.permute(partition_spec_list).tolist()


-def _get_group_assignment(
-    sharding_type: ShardingType, mesh: Mesh,
-    partition_spec: Tuple[Union[int, None]]) -> Tuple[List, List]:
+def _get_group_assignment(sharding_type: ShardingType, mesh: Mesh,
+                          partition_spec: Tuple[Union[int, None]],
+                          tile_assignment: List) -> Tuple[List, List]:
group_assignment = list()
replication_groups = list()
+  # TODO(JackCaoG): 3d mesh on 2d tensor
+  mesh_shape_list = list(torch.tensor(tile_assignment).size())
if sharding_type is ShardingType.PARTIAL:
# Shard across groups and replicate within subgroups; replicated dims
# will be used to group replication devices.
tile_dims = [d for d in partition_spec if d is not None]
-    replicated_dims = set(range(len(mesh.mesh_shape))) - set(tile_dims)
+    replicated_dims = set(range(len(mesh_shape_list))) - set(tile_dims)

-    group_list = [np.array(mesh.get_logical_mesh().tolist())]
+    group_list = [np.array(tile_assignment)]
for d in tile_dims:
_group_list = list()
for group_members in group_list:
-        _group_list += np.split(group_members, mesh.mesh_shape[d], d)
+        _group_list += np.split(group_members, mesh_shape_list[d], d)
group_list = _group_list
replication_groups = [group.flatten().tolist() for group in group_list]

-    group_tile_shape = list(mesh.mesh_shape)
+    group_tile_shape = mesh_shape_list
for d in replicated_dims:
group_tile_shape[d] = 1
group_assignment = np.arange(len(replication_groups)).reshape(
@@ -415,7 +413,6 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

-  tile_assignment = _get_tile_assignment(mesh, partition_spec)
# check for sharding 2D tensor on a 3D mesh
original_shape = tuple(t.shape)
# number of dims to expand on tensor
@@ -426,9 +423,10 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
shape = (1,) * tensor_expand + (*original_shape,)
t = t.expand(shape)

+  tile_assignment = _get_tile_assignment(mesh, partition_spec)
sharding_type = _get_sharding_type(partition_spec, num_devices)
group_assignment, replication_groups = _get_group_assignment(
-      sharding_type, mesh, partition_spec)
+      sharding_type, mesh, partition_spec, tile_assignment)

def tensor_squeeze(t, tensor_expand):
if tensor_expand:
@@ -484,7 +482,7 @@ def __post_init__(self):
self._sharding_type = _get_sharding_type(partition_spec,
xr.global_device_count())
self._group_assignment, self._replication_groups = _get_group_assignment(
-        self._sharding_type, mesh, partition_spec)
+        self._sharding_type, mesh, partition_spec, self._tile_assignment)

def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
"""
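The new grouping logic can be illustrated standalone. This sketch (illustrative names, not the library function) mirrors what the patched _get_group_assignment now does: it splits the permuted tile assignment, rather than the raw logical mesh, along every tiled mesh axis, and each block of devices left together becomes one replication group.

import numpy as np

def group_assignment_sketch(tile_assignment, partition_spec):
  tile = np.array(tile_assignment)
  mesh_shape = list(tile.shape)
  tile_dims = [d for d in partition_spec if d is not None]
  replicated_dims = set(range(len(mesh_shape))) - set(tile_dims)

  # Split the permuted device array along each tiled axis; the devices left
  # grouped together replicate the same shard.
  groups = [tile]
  for d in tile_dims:
    groups = [piece for g in groups for piece in np.split(g, mesh_shape[d], d)]
  replication_groups = [g.flatten().tolist() for g in groups]

  group_tile_shape = [1 if d in replicated_dims else s
                      for d, s in enumerate(mesh_shape)]
  group_assignment = np.arange(
      len(replication_groups)).reshape(group_tile_shape).tolist()
  return group_assignment, replication_groups

# Mesh (2, 1, 4) permuted by spec (1, 0, None) yields a (1, 2, 4) tile
# assignment, which for this particular spec is simply devices 0..7 in that shape.
tile = np.arange(8).reshape(1, 2, 4)
print(group_assignment_sketch(tile, (1, 0, None)))
# ([[[0], [1]]], [[0, 1, 2, 3], [4, 5, 6, 7]])

The two replication groups of four devices each correspond to the devices=[1,2,1,4] ... last_tile_dim_replicate annotation asserted in the new 3D test.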
