[dist][sharded_tensor] Fix ChunkShardingSpec metadata offsets for empty shards (#121002)

ChunkShardingSpec generated metadata whose shard offsets exceed the tensor size along the sharding dimension.

Example:

Torchrec prepared the following ShardedTensorMetadata:
```
ShardedTensorMetadata(shards_metadata=[
ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 512], placement=rank:0/cuda:0),
ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 512], placement=rank:1/cuda:1),
ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 512], placement=rank:2/cuda:2),
ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 512], placement=rank:3/cuda:3),
ShardMetadata(shard_offsets=[8, 0], shard_sizes=[2, 512], placement=rank:4/cuda:4),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:5/cuda:5),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:6/cuda:6)
],
size=torch.Size([10, 512]))
```
Calling ShardedTensor._init_from_local_shards_and_global_metadata(), the ShardedTensor's inferred ShardingSpec rebuilds the metadata as:

```
ShardedTensorMetadata(shards_metadata=[
ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 512], placement=rank:0/cuda:0),
ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 512], placement=rank:1/cuda:1),
ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 512], placement=rank:2/cuda:2),
ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 512], placement=rank:3/cuda:3),
ShardMetadata(shard_offsets=[8, 0], shard_sizes=[2, 512], placement=rank:4/cuda:4),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:5/cuda:5),
ShardMetadata(shard_offsets=[12, 0], shard_sizes=[0, 512], placement=rank:6/cuda:6)
],
size=torch.Size([10, 512]), tensor_properties=TensorProperties(dtype=torch.float16, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))
```
The deduced ChunkShardingSpec:
```
ChunkShardingSpec(dim=0, placements=[rank:0/cuda:0, rank:1/cuda:1, rank:2/cuda:2, rank:3/cuda:3, rank:4/cuda:4, rank:5/cuda:5, rank:6/cuda:6])
```

The fix is to limit the offsets by the sharding dimension size.
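
As a rough illustration of the offset arithmetic (a plain-Python sketch using the numbers from the example above; the `get_split_size` / `get_chunked_dim_size` helpers are reimplemented locally for illustration rather than imported from torch):

```
# Sketch of the chunk offset arithmetic for dim size 10 split across 7 placements.
# The helpers below mirror the ChunkShardingSpec split-size math for illustration only.
def get_split_size(dim_size, chunks):
    # ceil(dim_size / chunks)
    return -(-dim_size // chunks)

def get_chunked_dim_size(dim_size, split_size, idx):
    # size of chunk `idx`; trailing chunks may be empty (size 0)
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)

dim_size, chunks = 10, 7
split_size = get_split_size(dim_size, chunks)                          # 2

sizes = [get_chunked_dim_size(dim_size, split_size, i) for i in range(chunks)]
naive_offsets = [split_size * i for i in range(chunks)]
clamped_offsets = [min(split_size * i, dim_size) for i in range(chunks)]

print(sizes)            # [2, 2, 2, 2, 2, 0, 0]
print(naive_offsets)    # [0, 2, 4, 6, 8, 10, 12] -> last offset exceeds dim size 10
print(clamped_offsets)  # [0, 2, 4, 6, 8, 10, 10] -> matches the Torchrec metadata above
```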

Differential Revision: [D54419513](https://our.internmc.facebook.com/intern/diff/D54419513)
Pull Request resolved: #121002
Approved by: https://github.com/wz337
IvanKobzarev authored and pytorchmergebot committed Mar 2, 2024
1 parent 66b20b4 commit bab4b5a
Showing 2 changed files with 14 additions and 2 deletions.
5 changes: 5 additions & 0 deletions test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
```
@@ -261,6 +261,11 @@ def test_shard_tensor_with_empty_shard(self):
 
         # Verify.
         self.assertTrue(isinstance(st, sharded_tensor.ShardedTensor))
+        sms = st.metadata().shards_metadata
+        self.assertEqual(len(sms), 4)
+        for sm in sms:
+            self.assertTrue(sm.shard_offsets[0] + sm.shard_sizes[0] <= tensor.size(0))
+
         local_shard = st.local_tensor()
         self.assertEqual(1, len(st.local_shards()))
         if dist.get_rank() < 3:
```
11 changes: 9 additions & 2 deletions torch/distributed/_shard/sharding_spec/api.py
```
@@ -185,12 +185,14 @@ def _infer_sharding_spec_from_shards_metadata(shards_metadata):
     chunk_sharding_dim = None
     chunk_offset_list = []
     shard_size_list = []
+    shard_offset_list = []
     # collect local shard metadatas from the global sharded_tensor_metadata
     for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
         placements.append(shard_metadata.placement)
         local_offsets = shard_metadata.shard_offsets
         chunk_offset_list.append(sum(local_offsets))
         shard_size_list.append(shard_metadata.shard_sizes)
+        shard_offset_list.append(shard_metadata.shard_offsets)
         shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0]
         # If the offset is [0, 0, ..., 0] (all zeros),
         # we cannot decide whether how the tensor is sharded.
@@ -220,16 +222,21 @@ def _infer_sharding_spec_from_shards_metadata(shards_metadata):
             dim=chunk_sharding_dim,
             placements=placements,
         )
+
         shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list])
         shard_total_length = sum(shard_sizes)
+        shard_offsets = sorted([x[chunk_sharding_dim] for x in shard_offset_list])
+
         chunks = len(placements)
         split_size = get_split_size(shard_total_length, chunks)
         chunk_shard_sizes = sorted(
             [
                 get_chunked_dim_size(shard_total_length, split_size, idx)
-                for idx in range(len(placements))
+                for idx in range(chunks)
             ]
         )
-        if shard_sizes == chunk_shard_sizes:
+        # Should match ChunkShardingSpec offsets calculation
+        chunk_shard_offsets = [split_size * idx for idx in range(chunks)]
+        if shard_sizes == chunk_shard_sizes and shard_offsets == chunk_shard_offsets:
             return chunk_spec
     return EnumerableShardingSpec(shards_metadata)
```

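For reference, a plain-Python sketch of the comparison the updated `_infer_sharding_spec_from_shards_metadata` now performs, fed with the dim-0 sizes and offsets from the metadata in the commit message (the split-size arithmetic is inlined here for illustration):

```
# Dim-0 sizes and offsets taken from the Torchrec metadata in the commit message.
shard_sizes = sorted([2, 2, 2, 2, 2, 0, 0])            # -> [0, 0, 2, 2, 2, 2, 2]
shard_offsets = sorted([0, 2, 4, 6, 8, 10, 10])        # -> [0, 2, 4, 6, 8, 10, 10]

shard_total_length = sum(shard_sizes)                  # 10
chunks = 7
split_size = -(-shard_total_length // chunks)          # 2 (ceil division)

chunk_shard_sizes = sorted(
    max(min(shard_total_length, split_size * (idx + 1)) - split_size * idx, 0)
    for idx in range(chunks)
)                                                      # [0, 0, 2, 2, 2, 2, 2]
chunk_shard_offsets = [split_size * idx for idx in range(chunks)]  # [0, 2, 4, 6, 8, 10, 12]

print(shard_sizes == chunk_shard_sizes)        # True  -> the old size-only check passed
print(shard_offsets == chunk_shard_offsets)    # False -> the new offsets check fails,
                                               # so an EnumerableShardingSpec is returned
```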