[SPMD] Mesh to support custom device order. #4162
@@ -11,38 +11,60 @@
 import torch_xla.utils.utils as xu
 import torch_xla.experimental.xla_sharding as xs
 from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
+from torch_xla.experimental.pjrt import using_pjrt


-@unittest.skipIf((os.getenv('PJRT_DEVICE') == "") or (
-    xm.get_xla_supported_devices("GPU") is not None
-), "PyTorch/XLA SPMD requires PJRT_DEVICE={CPU, TPU}, GPU is currently not supported."
-                )
+@unittest.skipIf(not using_pjrt() or xm.get_xla_supported_devices("GPU"),
+                 f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
 class XlaShardingTest(unittest.TestCase):

+  n_devices = 0
+  device_ids = None
+
+  @classmethod
+  def setUpClass(cls):
+    cls.n_devices = len(xm.get_xla_supported_devices())
+    cls.device_ids = np.array(range(cls.n_devices))
+
+  def _get_mesh(self, mesh_shape, device_ids=None):
+    if device_ids is None:
+      device_ids = self.device_ids
+    assert len(device_ids) == self.n_devices
+    return xs.Mesh(device_ids, mesh_shape)
+
   def test_xla_sharded_tensor(self):
-    n_devices = xm.xrt_world_size()
-    mesh_shape = (1, n_devices)
     partition_spec = (0, 1)
     xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                        dtype=torch.float,
                        device=xm.xla_device())
-    xst1 = xs.mark_sharding(xt1, mesh_shape, partition_spec)
+    xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)),
+                            partition_spec)

     # TODO(244003536) add more tests for XLAShardedTensor.
     self.assertTrue(isinstance(xst1, XLAShardedTensor))

+  def test_custom_tile_assignment(self):
+    xt = torch.randn(10, 20).to(device=xm.xla_device())
+    mesh_shape = (1, self.n_devices)
Review comment: I see the tests have all devices mapped to a single axis - is there anything stopping us from using, e.g., a multi-axis mesh here?

Author reply: No, but for the unit tests a flat mesh is easier to work with, since we don't know how many devices we will have (e.g., for CPU we will have 1).
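For illustration only, a minimal sketch of what a non-flat mesh could look like with the Mesh class introduced in this PR, assuming the host exposes an even number of devices (e.g. 8 TPU cores); on a single-device CPU setup the branch is simply skipped:

import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Sketch: map the devices onto two mesh axes instead of one flat axis.
n_devices = len(xm.get_xla_supported_devices())
if n_devices >= 2 and n_devices % 2 == 0:
  device_ids = np.array(range(n_devices))
  # (2, n_devices // 2) satisfies Mesh's len(device_ids) == prod(mesh_shape) check.
  mesh_2d = xs.Mesh(device_ids, (2, n_devices // 2), ('x', 'y'))
  print(mesh_2d.get_logical_mesh())  # e.g. [[0 1 2 3] [4 5 6 7]] with 8 devices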
+
+    device_ids = np.flip(self.device_ids)
+    mesh = self._get_mesh(mesh_shape, device_ids)
+    xs.mark_sharding(xt, mesh, (0, 1))
+    annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join([
+        str(i) for i in reversed(range(self.n_devices))
+    ])) if self.n_devices > 1 else '{maximal device=0}'
+    self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
+
   def test_mark_sharding_2d(self):
     t1 = torch.randn(1, 128, device='cpu')
     t2 = torch.randn(1, 128, device='cpu')
     expected = t1 @ t2.T

     xt1 = t1.to(xm.xla_device())
     xt2 = t2.to(xm.xla_device())
-    n_devices = xm.xrt_world_size()
-    xs.mark_sharding(xt1, (1, n_devices), (0, 1))
-    annotation = '{devices=[1,%d]%s}' % (n_devices, ','.join(
-        [str(i)
-         for i in range(n_devices)])) if n_devices > 1 else '{maximal device=0}'
+    xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (0, 1))
+    annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join([
+        str(i) for i in range(self.n_devices)
+    ])) if self.n_devices > 1 else '{maximal device=0}'
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

     actual = (xt1 @ xt2.T).cpu()
@@ -53,28 +75,28 @@ def test_mark_sharding_4d(self):
     expected = t + t

     xt = t.to(xm.xla_device())
-    n_devices = xm.xrt_world_size()
-    xs.mark_sharding(xt, (1, 1, 1, n_devices), (0, 1, 2, 3))
-    annotation = '{devices=[1,1,1,%d]%s}' % (n_devices, ','.join(
-        [str(i)
-         for i in range(n_devices)])) if n_devices > 1 else '{maximal device=0}'
+    xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
+                     (0, 1, 2, 3))
+    annotation = '{devices=[1,1,1,%d]%s}' % (self.n_devices, ','.join([
+        str(i) for i in range(self.n_devices)
+    ])) if self.n_devices > 1 else '{maximal device=0}'
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

     actual = (xt + xt).cpu()
     self.assertTrue(torch.allclose(expected, actual))

   def test_clear_sharding(self):
     xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
-    n_devices = xm.xrt_world_size()
-    xs.mark_sharding(xt, (1, 1, 1, n_devices), (0, 1, 2, 3))
+    xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
+                     (0, 1, 2, 3))
     self.assertTrue(torch_xla._XLAC._get_xla_sharding_spec(xt))
     xs.clear_sharding(xt)
     self.assertFalse(torch_xla._XLAC._get_xla_sharding_spec(xt))

   def test_deep_copy(self):
     xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
-    n_devices = xm.xrt_world_size()
-    xs.mark_sharding(xt, (1, 1, 1, n_devices), (0, 1, 2, 3))
+    xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
+                     (0, 1, 2, 3))
     xt2 = copy.deepcopy(xt)
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(xt),
@@ -1,49 +1,110 @@
+from collections import OrderedDict
 import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
 from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
+from torch_xla.experimental.pjrt import requires_pjrt

 import numpy as np
-from typing import Tuple, Union
+from typing import Tuple, Union, List


-def mark_sharding(t: Union[torch.Tensor,
-                           XLAShardedTensor], mesh_shape: Tuple[int],
+class Mesh:
+  """Describes the logical XLA device topology mesh and the underlying resources.
+
+  Args:
+    device_ids (Union[np.ndarray, List]): A raveled list of device IDs in a custom order.
+        The list is reshaped to a `mesh_shape` array, filling the elements using
+        C-like index order.
+
+    mesh_shape (Tuple[int, ...]): An int tuple describing the logical topology shape
+        of the device mesh, where each element describes the number of devices in
+        the corresponding axis.
+
+    axis_names (Tuple[str, ...]): A sequence of resource axis names to be assigned to
+        the dimensions of the `devices` argument. Its length should match the rank
+        of `devices`.
+
+  Example:
+  ------------------------------
+  mesh_shape = (4, 2)
+  num_devices = len(xm.get_xla_supported_devices())
+  device_ids = np.array(range(num_devices))
+  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+  mesh.get_logical_mesh()
+  >> array([[0, 1],
+            [2, 3],
+            [4, 5],
+            [6, 7]])
+  mesh.shape()
+  >> OrderedDict([('x', 4), ('y', 2)])
+  """
+
+  device_ids: np.ndarray
+  mesh_shape: Tuple[int, ...]
+  axis_names: Tuple[str, ...]
+
+  def __init__(self,
+               device_ids: Union[np.ndarray, List],
+               mesh_shape: Tuple[int, ...],
+               axis_names: Tuple[str, ...] = None):
Review comment: Just curious - how will the axis_names be used?

Author reply: Good question - the mesh axis annotation is useful since it makes the annotation logic more readable. We can also build a partitioning rule based on the axis name, instead of int indices.
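As a sketch of that second point, a small hypothetical helper (not part of this PR) could translate named axes into the integer mesh dimensions that mark_sharding expects:

# Hypothetical helper, assuming a mesh built with axis names, e.g.
# mesh = xs.Mesh(device_ids, (4, 2), ('data', 'model')).
def partition_spec_from_names(mesh, named_spec):
  name_to_dim = {name: dim for dim, name in enumerate(mesh.axis_names)}
  # None still means "replicate along this tensor dimension".
  return tuple(None if axis is None else name_to_dim[axis] for axis in named_spec)

# partition_spec_from_names(mesh, ('data', None)) -> (0, None)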
+    if not isinstance(device_ids, np.ndarray):
+      device_ids = np.array(device_ids)
+    assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
+    assert (len(device_ids) == np.prod(mesh_shape))
+    assert len(device_ids) == len(np.unique(device_ids))
+    self.device_ids = device_ids
+    self.mesh_shape = mesh_shape
+    self.axis_names = axis_names
+    assert all(d < self.size() for d in device_ids)
+
+  def size(self):
+    return np.prod(self.mesh_shape)
+
+  def shape(self):
+    return OrderedDict(
+        (name, size) for name, size in zip(self.axis_names, self.mesh_shape))
+
+  def get_logical_mesh(self):
+    return self.device_ids.reshape(self.mesh_shape)
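To make the new behavior concrete, a minimal usage sketch of a mesh with a custom (reversed) device order, assuming 8 available devices:

import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# The device_ids argument now controls the tile assignment order.
device_ids = np.flip(np.arange(len(xm.get_xla_supported_devices())))
mesh = xs.Mesh(device_ids, (4, 2), ('x', 'y'))
print(mesh.get_logical_mesh())
# Expected with 8 devices:
# [[7 6]
#  [5 4]
#  [3 2]
#  [1 0]]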
+@requires_pjrt
+def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
+                  partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor:
   """
   Annotates the tensor provided with the XLA partition spec. Internally,
   it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.

   Args:
     t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec.

-    mesh_shape (Tuple[Union[int, None]]): A int tuple describing the logical topology
-        of the device mesh, and each element describes the number of devices in
-        the corresponding axis.
+    mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.

     partition_spec (Tuple[int, None]): A tuple of device mesh dimension indices or `None`.
         This specifies how each input rank is sharded (indexed into mesh_shape) or replicated (None).
         For example, we can shard an 8x10 tensor 4-way row-wise and replicate it column-wise.
         >> input = torch.randn(8, 10)
         >> mesh_shape = (4, 2)
         >> assert np.prod(mesh_shape) == xm.xrt_world_size()
         >> partition_spec = (0, None)
         >> assert len(input.shape) == len(partition_spec)

   Examples
   ------------------------------
   mesh_shape = (4, 2)
-  input = torch.randn(8, 32).to(xm.xla_device())
+  num_devices = len(xm.get_xla_supported_devices())
+  device_ids = np.array(range(num_devices))
+  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

   # 4-way data parallel
-  input = xs.mark_sharding(input, mesh_shape, (0, None))
-  linear = nn.Linear(32, 10).to(xm.xla_device())
+  input = torch.randn(8, 32).to(xm.xla_device())
+  xs.mark_sharding(input, mesh, (0, None))

   # 2-way model parallel
-  linear.weight = xs.mark_sharding(linear.weight, device_mesh, (None, 1))
-  output = linear(input)
-  # full replication
-  output = xs.mark_sharding(output, device_mesh, (None, None))
+  linear = nn.Linear(32, 10).to(xm.xla_device())
+  xs.mark_sharding(linear.weight, mesh, (None, 1))
   """
   num_devices = len(xm.get_xla_supported_devices())
   assert num_devices > 0, "This requires XLA supported device(s)."
-  assert np.prod(mesh_shape) == num_devices, \
-    f"{mesh_shape} is not mappable over {num_devices} devices."
-  assert all((d >= 0 and d < len(mesh_shape)) for d in partition_spec if d), \
+  assert mesh.size() == num_devices, \
+    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
+  assert all((d >= 0 and d < len(mesh.mesh_shape)) for d in partition_spec if d), \
     f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
   # TODO(yeounoh) allow unspecified ranks (len(partition_spec) <= len(t.shape)),
   # for replication. For now, all input rank sharding should be specified.
@@ -53,15 +114,12 @@ def mark_sharding(t: Union[torch.Tensor,
   assert len(dims) == len(np.unique(dims)), \
     f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

-  device_ids = np.array(range(num_devices))
-  tile_assignment = device_ids.reshape(mesh_shape).tolist()
-
+  tile_assignment = mesh.get_logical_mesh().tolist()
   manual, replicated, partial = False, False, False
   if all(d is None for d in partition_spec):
     replicated = True
   elif any(d is None for d in partition_spec):
     partial = True

   # TODO(yeounoh) support partially replicated sharding.
   assert not partial, "Partial replication is currently not supported."
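Putting the pieces together, a sketch of the full flow with a reversed device order, mirroring test_custom_tile_assignment above (assuming 8 devices under the PJRT runtime):

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

n_devices = len(xm.get_xla_supported_devices())
mesh = xs.Mesh(np.flip(np.arange(n_devices)), (1, n_devices))
t = torch.randn(10, 20).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))
# With 8 devices the recorded spec should read '{devices=[1,8]7,6,5,4,3,2,1,0}',
# i.e. the custom device order flows through to the tile assignment.
print(torch_xla._XLAC._get_xla_sharding_spec(t))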
Review comment: @will-cromar I think PJRT-GPU single core is ready now?

Author reply: It's blocked on our SPMD side; once we support TPU, the transition to GPU should be easier -- maybe sometime next year, once we are done with the basic/core SPMD features?