Skip to content

Commit

Permalink
[shard] add back init_from_local_shard_and_global_metadata API (pytor…
Browse files Browse the repository at this point in the history
…ch#69226)

Summary:
Pull Request resolved: pytorch#69226

This adds back the previous `init_from_local_shards` API, but renames it to `init_from_local_shards_and_global_metadata`. It's a partial revert of D32147888 (pytorch@35712a8). We now provide two APIs:
1. `init_from_local_shards`: users don't need to provide global metadata; we do an all_gather under the hood to build it.
2. `init_from_local_shards_and_global_metadata`: users need to explicitly construct a ShardedTensorMetadata to use this API and must ensure its correctness on all ranks, as there is no cross-rank communication/validation.

Both APIs stay private until the UX stabilizes and is proven. The second one can only be called on the `ShardedTensor` class directly; it is not included as a package API for now.

Test Plan:
test_init_from_local_shards_and_global_metadata
test_init_from_local_shards_and_global_metadata_invalid_shards

Reviewed By: dstaay-fb, pritamdamania87

Differential Revision: D32746882

fbshipit-source-id: bafd26ce16c02e2095907f9e59984a5d775c7df5
  • Loading branch information
wanchaol authored and pull[bot] committed Feb 15, 2023
1 parent c620635 commit 485b787
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 0 deletions.
182 changes: 182 additions & 0 deletions test/distributed/_sharded_tensor/test_sharded_tensor.py
Expand Up @@ -22,6 +22,7 @@
)
from torch.distributed._sharded_tensor.api import (
CreateOp,
ShardedTensor,
TensorInitParams,
TensorProperties,
_create_tensor_from_params,
Expand Down Expand Up @@ -1763,6 +1764,76 @@ def test_init_from_local_shards(self):
shard = remote_shard.to_here()
self.assertEqual((5, 5), shard.tensor.size())


@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_init_from_local_shards_and_global_metadata(self):
    """
    Build a 10x10 sharded tensor from 4 ranks, each owning one (5, 5)
    shard on its own CUDA device, via
    ShardedTensor._init_from_local_shards_and_global_metadata, then
    verify the local shard, local/global metadata, and remote shards.
    """
    # Ranks are laid out on a 2x2 grid: rank r owns the shard at
    # row-offset (r // 2) * 5, col-offset (r % 2) * 5.
    local_shard_metadata = ShardMetadata(
        shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
        shard_sizes=[5, 5],
        placement=f"rank:{self.rank}/cuda:{self.rank}"
    )

    # Global list of shard metadata, constructed identically on every
    # rank (no cross-rank communication happens in this API).
    shards_metadata = []
    for r in range(self.world_size):
        if r == self.rank:
            shards_metadata.append(local_shard_metadata)
        else:
            shards_metadata.append(ShardMetadata(
                shard_offsets=[(r // 2) * 5, (r % 2) * 5],
                shard_sizes=[5, 5],
                placement=f"rank:{r}/cuda:{r}"
            ))

    local_shards = [_sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)]

    sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([10, 10]),
        dtype=torch.get_default_dtype(),
        layout=torch.strided,
        requires_grad=False,
        memory_format=torch.contiguous_format,
        pin_memory=False,
    )

    # BUGFIX: init_rrefs=True is required here — the remote-shard
    # validation below reads sharded_tensor.remote_shards, and rrefs
    # are only registered in _post_init when init_rrefs is True (the
    # API defaults to False). Matches the sibling tests that pass
    # init_rrefs=True before validating remote shards.
    sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        sharded_tensor_metadata,
        init_rrefs=True,
    )
    self.assertEqual((10, 10), sharded_tensor.size())
    self.assertEqual(1, len(sharded_tensor.local_shards()))

    # Verify local shard.
    local_shard = sharded_tensor.local_shards()[0]
    self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
    self.assertEqual((5, 5), local_shard.tensor.size())

    # Verify local shard metadata.
    self.assertEqual((self.rank // 2 * 5, (self.rank % 2) * 5), local_shard.metadata.shard_offsets)
    self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
    self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)

    # Verify global metadata.
    shards_metadata = sharded_tensor.metadata().shards_metadata
    self.assertEqual(4, len(shards_metadata))
    for rank, shard_metadata in enumerate(shards_metadata):
        self.assertEqual((rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_sizes)
        self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)

    # Validate remote shards: every other rank contributes exactly one.
    remote_shards = sharded_tensor.remote_shards
    self.assertEqual(3, len(remote_shards))

    for rpc_rank, shards in remote_shards.items():
        self.assertEqual(1, len(shards))
        for remote_shard in shards:
            self.assertEqual(rpc_rank, remote_shard.owner().id)
            shard = remote_shard.to_here()
            self.assertEqual((5, 5), shard.tensor.size())

@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
Expand Down Expand Up @@ -1962,5 +2033,116 @@ def test_init_from_local_shards_invalid_shards_gaps(self):
with self.assertRaisesRegex(ValueError, "does not match tensor volume"):
sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, [10, 10], init_rrefs=True)

@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_init_from_local_shards_and_global_metadata_invalid_shards(self):
    """
    Verify that _init_from_local_shards_and_global_metadata rejects
    local shards that are inconsistent with the provided global
    metadata: wrong shard count, size, device, dtype, layout,
    requires_grad, pin_memory, or memory format.
    """
    local_shard_metadata = ShardMetadata(
        shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
        shard_sizes=[5, 5],
        placement=f"rank:{self.rank}/cuda:{self.rank}"
    )

    shards_metadata = []
    for r in range(self.world_size):
        if r == self.rank:
            shards_metadata.append(local_shard_metadata)
        else:
            shards_metadata.append(ShardMetadata(
                shard_offsets=[(r // 2) * 5, (r % 2) * 5],
                shard_sizes=[5, 5],
                placement=f"rank:{r}/cuda:{r}"
            ))

    sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([10, 10]),
        dtype=torch.get_default_dtype(),
        layout=torch.strided,
        requires_grad=False,
        memory_format=torch.contiguous_format,
        pin_memory=False,
    )

    empty_local_shards = []
    with self.assertRaisesRegex(RuntimeError, 'does not match number of local shards metadata'):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            empty_local_shards,
            sharded_tensor_metadata
        )

    wrong_num_shards = [
        _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata),
        _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)
    ]
    with self.assertRaisesRegex(RuntimeError, 'does not match number of local shards metadata'):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            wrong_num_shards,
            sharded_tensor_metadata
        )

    # BUGFIX: the expected-message regexes below are aligned with what
    # _raise_if_mismatch in the API actually raises:
    #   "Local shards' tensor <prop> property is incompatible with
    #    {local ShardMetadata | tensor property} on rank <r>: ..."
    # The previous patterns ("Local shard tensor ... does not match
    # with ...") never appear in those messages, so assertRaisesRegex
    # (which uses re.search) could not match them.
    wrong_size_shards = [_sharded_tensor.Shard(torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata)]
    with self.assertRaisesRegex(ValueError, "Local shards' tensor size property is incompatible with"):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            wrong_size_shards,
            sharded_tensor_metadata
        )

    wrong_device_shards = [_sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)]
    with self.assertRaisesRegex(ValueError, "Local shards' tensor device property is incompatible with"):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            wrong_device_shards,
            sharded_tensor_metadata
        )

    wrong_dtype_shards = [
        _sharded_tensor.Shard(torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=torch.int), local_shard_metadata)
    ]
    with self.assertRaisesRegex(ValueError, "Local shards' tensor dtype property is incompatible with"):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            wrong_dtype_shards,
            sharded_tensor_metadata
        )

    indices = [[0, 1, 1], [2, 0, 2]]
    values = [3.2, 4.5, 5.8]
    sparse_tensor = torch.sparse_coo_tensor(indices, values, (5, 5), device=f"cuda:{self.rank}")

    wrong_layout_shards = [
        _sharded_tensor.Shard(sparse_tensor, local_shard_metadata)
    ]
    with self.assertRaisesRegex(ValueError, "Local shards' tensor layout property is incompatible with"):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            wrong_layout_shards,
            sharded_tensor_metadata
        )

    wrong_requires_grad_shards = [
        _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}", requires_grad=True), local_shard_metadata)
    ]
    with self.assertRaisesRegex(ValueError, "Local shards' tensor requires_grad property is incompatible with"):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            wrong_requires_grad_shards,
            sharded_tensor_metadata
        )

    # NOTE: pin_memory is validated before device in the API, so a
    # pinned CPU tensor trips the pin_memory check, not the device one.
    wrong_pin_memory_shards = [
        _sharded_tensor.Shard(torch.randn(5, 5, pin_memory=True), local_shard_metadata)
    ]
    with self.assertRaisesRegex(ValueError, "Local shards' tensor pin_memory property is incompatible with"):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            wrong_pin_memory_shards,
            sharded_tensor_metadata
        )

    wrong_memory_format_shards = [
        _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata)
    ]
    with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'):
        sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
            wrong_memory_format_shards,
            sharded_tensor_metadata
        )

# Run the test suite via torch's common test harness when executed as a script.
if __name__ == '__main__':
    run_tests()
100 changes: 100 additions & 0 deletions torch/distributed/_sharded_tensor/api.py
Expand Up @@ -22,6 +22,7 @@
check_tensor,
get_split_size,
get_chunked_dim_size,
validate_non_overlapping_shards_metadata,
)
from torch.types import Number
from .metadata import TensorProperties, ShardedTensorMetadata
Expand Down Expand Up @@ -369,6 +370,105 @@ def _init_from_local_shards(
sharded_tensor._post_init()
return sharded_tensor

@classmethod
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
):
    """
    Initialize a ShardedTensor with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
    not do cross rank validations, and fully rely on the user
    for the correctness of sharded_tensor_metadata on each rank

    Args:
        local_shards: list of Shards owned by the current rank; must
            match (in count and metadata) the entries of
            sharded_tensor_metadata placed on this rank.
        sharded_tensor_metadata: global metadata (shards_metadata plus
            tensor properties) assumed identical on every rank.
        process_group: process group to use; defaults to the default
            c10d group.
        init_rrefs: whether to register RRefs for remote shards in
            _post_init (required for remote_shards access).

    Returns:
        The initialized ShardedTensor.

    Raises:
        ValueError: empty shards_metadata, non-strided layout, or a
            local shard tensor incompatible with the metadata.
        RuntimeError: local shard count does not match the metadata
            entries for this rank.
    """
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)

    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError('Only torch.strided layout is currently supported')

    # Bypass __init__ (which builds a tensor from a sharding spec); we
    # only need the bookkeeping set up by _prepare_init.
    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    sharded_tensor._metadata = sharded_tensor_metadata

    local_shard_metadatas = []

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
        if expected != actual:
            raise ValueError(f"Local shards' tensor {prop_name} property is incompatible with "
                             f"{tensor_property_or_metadata} on rank {rank}: "
                             f"{tensor_property_or_metadata} {prop_name}={expected}, "
                             f"local shard tensor {prop_name}={actual}.")

    # collect local shard metadatas from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, local_device = _parse_and_validate_remote_device(sharded_tensor._process_group, shard_metadata.placement)

        if current_rank == rank:
            local_shard_metadatas.append(shard_metadata)

    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) '
        )

    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        rank, local_device = _parse_and_validate_remote_device(sharded_tensor._process_group, shard_meta.placement)

        # validate if shard_meta in the metadatas collected from sharded_tensor_metadata
        assert shard_meta in local_shard_metadatas, \
            "local shard metadata not in sharded_tensor_metadata!"

        # BUGFIX: validate layout BEFORE the contiguity check. Calling
        # is_contiguous() on a non-strided (e.g. sparse COO) tensor
        # raises a RuntimeError, masking the intended layout ValueError.
        # For strided tensors this reordering is a no-op, since the
        # layout check always passes there (metadata layout is already
        # guaranteed to be torch.strided above).
        _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout, "layout", current_rank, True)

        if not local_shard_tensor.is_contiguous():
            raise ValueError('Only torch.contiguous_format memory_format is currently supported')

        _raise_if_mismatch(shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(tensor_properties.pin_memory, local_shard_tensor.is_pinned(), "pin_memory", current_rank, True)
        _raise_if_mismatch(local_device, local_shard_tensor.device, "device", current_rank)
        _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype, "dtype", current_rank, True)
        _raise_if_mismatch(
            tensor_properties.requires_grad, local_shard_tensor.requires_grad, "requires_grad", current_rank, True)

    # check if shards_metadata have overlap shards
    validate_non_overlapping_shards_metadata(shards_metadata)

    # check if the shards_metadata is compatible with overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # done validation, add local_shards
    sharded_tensor._local_shards = local_shards
    # make a EnumerableShardingSpec for sharded tensors that initialized from this API.
    # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
    # see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(shards_metadata)

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor


def _init_chunked(self, dims, tensor_init_params: TensorInitParams, ):
current_rank = dist.get_rank(self._process_group)
sharding_dim = self._sharding_spec.dim # type: ignore[attr-defined]
Expand Down

0 comments on commit 485b787

Please sign in to comment.