[dist][sharded_tensor] Fix ChunkShardingSpec metadata offsets for empty shards (#121002)

ChunkShardingSpec generated metadata in which shard offsets exceed the tensor size. Example:

Torchrec prepared this ShardedTensorMetadata:

```
ShardedTensorMetadata(shards_metadata=[
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 512], placement=rank:0/cuda:0),
    ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 512], placement=rank:1/cuda:1),
    ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 512], placement=rank:2/cuda:2),
    ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 512], placement=rank:3/cuda:3),
    ShardMetadata(shard_offsets=[8, 0], shard_sizes=[2, 512], placement=rank:4/cuda:4),
    ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:5/cuda:5),
    ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:6/cuda:6)
], size=torch.Size([10, 512]),
```

When ShardedTensor._init_from_local_shards_and_global_metadata() is called, the ShardedTensor's ShardingSpec builds this metadata:

```
ShardedTensorMetadata(shards_metadata=[
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 512], placement=rank:0/cuda:0),
    ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 512], placement=rank:1/cuda:1),
    ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 512], placement=rank:2/cuda:2),
    ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 512], placement=rank:3/cuda:3),
    ShardMetadata(shard_offsets=[8, 0], shard_sizes=[2, 512], placement=rank:4/cuda:4),
    ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:5/cuda:5),
    ShardMetadata(shard_offsets=[12, 0], shard_sizes=[0, 512], placement=rank:6/cuda:6)
], size=torch.Size([10, 512]),
tensor_properties=TensorProperties(dtype=torch.float16, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))
```

Note that rank 6's empty shard gets offset [12, 0], which exceeds the tensor size of 10 along dim 0 and no longer matches the Torchrec-prepared metadata.

The deduced ChunkShardingSpec:

```
ChunkShardingSpec(dim=0, placements=[rank:0/cuda:0, rank:1/cuda:1, rank:2/cuda:2, rank:3/cuda:3, rank:4/cuda:4, rank:5/cuda:5, rank:6/cuda:6])
```

The fix is to limit offsets by the dim size.

Differential Revision: [D54419513](https://our.internmc.facebook.com/intern/diff/D54419513)

Pull Request resolved: #121002
Approved by: https://github.com/wz337
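Not part of the commit itself, but a minimal sketch of the idea behind the fix: when chunk-style sharding produces empty trailing shards, clamping each shard's offset to the dimension size keeps offsets from running past the tensor. The helper name `chunk_offsets_and_sizes` is hypothetical and only illustrates the arithmetic, not the actual PyTorch implementation.

```python
def chunk_offsets_and_sizes(dim_size: int, num_chunks: int):
    """Hypothetical helper: per-chunk (offset, length) along the sharding dim."""
    # Ceil-divide split, the same rule torch.chunk follows; trailing chunks may be empty.
    split_size = -(-dim_size // num_chunks)
    shards = []
    for i in range(num_chunks):
        # Clamp the offset so it never exceeds dim_size (the idea behind the fix);
        # without the min(), empty shards would get offsets like 12 for a dim of 10.
        offset = min(i * split_size, dim_size)
        length = min(split_size, dim_size - offset)
        shards.append((offset, length))
    return shards

# dim_size=10 split across 7 ranks yields
# [(0, 2), (2, 2), (4, 2), (6, 2), (8, 2), (10, 0), (10, 0)],
# matching the Torchrec-prepared metadata above; unclamped, rank 6 would get
# offset 12, which is the mismatch this commit fixes.
print(chunk_offsets_and_sizes(10, 7))
```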