64 changes: 43 additions & 21 deletions test/test_xla_sharding.py
@@ -11,38 +11,60 @@
import torch_xla.utils.utils as xu
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.pjrt import using_pjrt


@unittest.skipIf((os.getenv('PJRT_DEVICE') == "") or (
xm.get_xla_supported_devices("GPU") is not None
), "PyTorch/XLA SPMD requires PJRT_DEVICE={CPU, TPU}, GPU is currently not supported."
)
@unittest.skipIf(not using_pjrt() or xm.get_xla_supported_devices("GPU"),
f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
Collaborator

@will-cromar I think PJRT-GPU single core is ready now?

Contributor Author

It's blocked on our SPMD side; once we support TPU, the transition to GPU should be easier -- maybe sometime next year, once we are done with the basic/core SPMD features?
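
For anyone trying this suite locally, a minimal sketch of satisfying the decorator above on a CPU-only host. PJRT_DEVICE is the runtime selector named in the skip message; setting it in-process like this (before torch_xla initializes) is just one convenient option, not part of this PR:

import os

# Assumption: executed before torch_xla creates its runtime client.
os.environ.setdefault('PJRT_DEVICE', 'CPU')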

class XlaShardingTest(unittest.TestCase):

n_devices = 0
device_ids = None

@classmethod
def setUpClass(cls):
cls.n_devices = len(xm.get_xla_supported_devices())
cls.device_ids = np.array(range(cls.n_devices))

def _get_mesh(self, mesh_shape, device_ids=None):
if device_ids is None:
device_ids = self.device_ids
assert len(device_ids) == self.n_devices
return xs.Mesh(device_ids, mesh_shape)

def test_xla_sharded_tensor(self):
n_devices = xm.xrt_world_size()
mesh_shape = (1, n_devices)
partition_spec = (0, 1)
xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
dtype=torch.float,
device=xm.xla_device())
xst1 = xs.mark_sharding(xt1, mesh_shape, partition_spec)
xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)),
partition_spec)

# TODO(244003536) add more tests for XLAShardedTensor.
self.assertTrue(isinstance(xst1, XLAShardedTensor))

def test_custom_tile_assignment(self):
xt = torch.randn(10, 20).to(device=xm.xla_device())
mesh_shape = (1, self.n_devices)
Collaborator

I see the tests have all devices mapped to a single axis - is there anything stopping us from using e.g. mesh_shape = (2, self.n_devices / 2)?

Contributor Author

No, but for unit testing a flat mesh is easier to work with, since we don't know how many devices we will have (e.g., for CPU we will have 1).
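
To make the suggestion above concrete, a hedged sketch of a 2D mesh using this PR's API; it assumes an even device count greater than one, which is exactly what the tests cannot rely on:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

n_devices = len(xm.get_xla_supported_devices())
if n_devices > 1 and n_devices % 2 == 0:
  mesh = xs.Mesh(np.arange(n_devices), (2, n_devices // 2))
  xt = torch.randn(10, 20).to(xm.xla_device())
  # Shard rows 2-way and columns (n_devices / 2)-way.
  xs.mark_sharding(xt, mesh, (0, 1))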

device_ids = np.flip(self.device_ids)
mesh = self._get_mesh(mesh_shape, device_ids)
xs.mark_sharding(xt, mesh, (0, 1))
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join([
str(i) for i in reversed(range(self.n_devices))
])) if self.n_devices > 1 else '{maximal device=0}'
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
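
For concreteness, the spec string asserted above, derived by evaluating the test's own format string with four devices and the flipped IDs (illustrative only):

n = 4
spec = '{devices=[1,%d]%s}' % (n, ','.join(
    str(i) for i in reversed(range(n))))
assert spec == '{devices=[1,4]3,2,1,0}'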

def test_mark_sharding_2d(self):
t1 = torch.randn(1, 128, device='cpu')
t2 = torch.randn(1, 128, device='cpu')
expected = t1 @ t2.T

xt1 = t1.to(xm.xla_device())
xt2 = t2.to(xm.xla_device())
n_devices = xm.xrt_world_size()
xs.mark_sharding(xt1, (1, n_devices), (0, 1))
annotation = '{devices=[1,%d]%s}' % (n_devices, ','.join(
[str(i)
for i in range(n_devices)])) if n_devices > 1 else '{maximal device=0}'
xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (0, 1))
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join([
str(i) for i in range(self.n_devices)
])) if self.n_devices > 1 else '{maximal device=0}'
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

actual = (xt1 @ xt2.T).cpu()
@@ -53,28 +75,28 @@ def test_mark_sharding_4d(self):
expected = t + t

xt = t.to(xm.xla_device())
n_devices = xm.xrt_world_size()
xs.mark_sharding(xt, (1, 1, 1, n_devices), (0, 1, 2, 3))
annotation = '{devices=[1,1,1,%d]%s}' % (n_devices, ','.join(
[str(i)
for i in range(n_devices)])) if n_devices > 1 else '{maximal device=0}'
xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
(0, 1, 2, 3))
annotation = '{devices=[1,1,1,%d]%s}' % (self.n_devices, ','.join([
str(i) for i in range(self.n_devices)
])) if self.n_devices > 1 else '{maximal device=0}'
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

actual = (xt + xt).cpu()
self.assertTrue(torch.allclose(expected, actual))

def test_clear_sharding(self):
xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
n_devices = xm.xrt_world_size()
xs.mark_sharding(xt, (1, 1, 1, n_devices), (0, 1, 2, 3))
xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
(0, 1, 2, 3))
self.assertTrue(torch_xla._XLAC._get_xla_sharding_spec(xt))
xs.clear_sharding(xt)
self.assertFalse(torch_xla._XLAC._get_xla_sharding_spec(xt))

def test_deep_copy(self):
xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
n_devices = xm.xrt_world_size()
xs.mark_sharding(xt, (1, 1, 1, n_devices), (0, 1, 2, 3))
xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
(0, 1, 2, 3))
xt2 = copy.deepcopy(xt)
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(xt),
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_util.cpp
@@ -925,7 +925,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
// host communications. This means that we may need to manually shard
// across global devices for multi-host training.
std::vector<std::string> local_devices =
xla::ComputationClient::Get()->GetLocalDevices();
xla::ComputationClient::Get()->GetAllDevices();
// Shards the input tensors with padding, to split evenly.
// The execution requires consistent shard sizes, and the zero-padded
// values should be ignored.
36 changes: 20 additions & 16 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -334,25 +334,29 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
int64_t core = sharding.tile_assignment_devices()[i];
shards[core] = shard.contiguous(at::MemoryFormat::Contiguous);
}
} else if ((sharding.type() == xla::OpSharding::MANUAL) ||
(sharding.type() == xla::OpSharding::TUPLE)) {
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
}

// Zero-pad to the right to ensure the sizes are even
if (shards.size() > 0 && padded) {
for (size_t i = 1; i < shards.size(); ++i) {
std::vector<long> pads;
for (size_t j = 0; j < shards[i].sizes().size(); ++j) {
XLA_CHECK_GE(shards[0].sizes().at(j), shards[i].sizes().at(j));
pads.push_back(shards[0].sizes().at(j) - shards[i].sizes().at(j));
pads.push_back(0); // no padding on lhs
// Zero-pad to the right to ensure the sizes are even
if (shards.size() > 0 && padded) {
for (size_t i = 1; i < shards.size(); ++i) {
std::vector<long> pads;
for (size_t j = 0; j < shards[i].sizes().size(); ++j) {
XLA_CHECK_GE(
shards[sharding.tile_assignment_devices()[0]].sizes().at(j),
shards[i].sizes().at(j));
pads.push_back(
shards[sharding.tile_assignment_devices()[0]].sizes().at(j) -
shards[i].sizes().at(j));
pads.push_back(0); // no padding on lhs
}
// Padding starts from the last dimension
std::reverse(pads.begin(), pads.end());
shards[i] = at::constant_pad_nd(
shards[i], c10::IntArrayRef(pads.data(), pads.size()), 0);
}
// Padding starts from the last dimension
std::reverse(pads.begin(), pads.end());
shards[i] = at::constant_pad_nd(
shards[i], c10::IntArrayRef(pads.data(), pads.size()), 0);
}
} else if ((sharding.type() == xla::OpSharding::MANUAL) ||
(sharding.type() == xla::OpSharding::TUPLE)) {
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
}
return shards;
}
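
As a reading aid, a rough Python sketch (not part of the PR) of the padding rule this relocated block implements; it assumes the reference shard holds the maximal size in every dimension, mirroring the XLA_CHECK_GE above:

import torch
import torch.nn.functional as F

def pad_shards_to_even(shards, ref=0):
  # Pad every shard with zeros on the right of each dimension until it
  # matches the reference shard's shape, so all shard sizes are consistent.
  full = shards[ref].shape
  out = []
  for s in shards:
    # F.pad expects (left, right) pairs starting from the last dimension;
    # padding only on the right matches the C++ loop plus std::reverse.
    pads = [p for d in reversed(range(s.dim()))
            for p in (0, full[d] - s.shape[d])]
    out.append(F.pad(s, pads))
  return out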
102 changes: 80 additions & 22 deletions torch_xla/experimental/xla_sharding.py
@@ -1,49 +1,110 @@
from collections import OrderedDict
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.pjrt import requires_pjrt

import numpy as np
from typing import Tuple, Union
from typing import Tuple, Union, List


def mark_sharding(t: Union[torch.Tensor,
XLAShardedTensor], mesh_shape: Tuple[int],
class Mesh:
"""Describe the logical XLA device topology mesh and the underlying resources.

Args:
device_ids (Union[np.ndarray, List]): A raveled list of devices (IDs) in a custom order. The list is reshaped
to a `mesh_shape` array, filling the elements using C-like index order.

mesh_shape (Tuple[int, ...]): An int tuple describing the logical topology shape
of the device mesh; each element describes the number of devices in
the corresponding axis.

axis_names (Tuple[str, ...]): A sequence of resource axis names to be assigned to the dimensions
of the `devices` argument. Its length should match the rank of `devices`.

Example:
------------------------------
mesh_shape = (4, 2)
num_devices = len(xm.get_xla_supported_devices())
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh.get_logical_mesh()
>> array([[0, 1],
[2, 3],
[4, 5],
[6, 7]])
mesh.shape()
>> OrderedDict([('x', 4), ('y', 2)])
"""

device_ids: np.ndarray
mesh_shape: Tuple[int, ...]
axis_names: Tuple[str, ...]

def __init__(self,
device_ids: Union[np.ndarray, List],
mesh_shape: Tuple[int, ...],
axis_names: Tuple[str, ...] = None):
Collaborator

Just curious - how will axis_names be used long-term? Is it just for annotating the mesh?

Contributor Author

Good question, mesh axis annotation is useful since it makes the annotation logic more readable. We can also build partitioning rules based on axis names instead of int indices.
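
As a hypothetical sketch of that idea (not in this PR): named axes could be resolved into the integer partition spec that mark_sharding expects.

def spec_from_axis_names(mesh, names):
  # e.g. mesh.axis_names == ('x', 'y') and names == ('x', None) -> (0, None)
  return tuple(
      None if n is None else mesh.axis_names.index(n) for n in names)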

if not isinstance(device_ids, np.ndarray):
device_ids = np.array(device_ids)
assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
assert (len(device_ids) == np.prod(mesh_shape))
assert len(device_ids) == len(np.unique(device_ids))
self.device_ids = device_ids
self.mesh_shape = mesh_shape
self.axis_names = axis_names
assert all(d < self.size() for d in device_ids)

def size(self):
return np.prod(self.mesh_shape)

def shape(self):
return OrderedDict(
(name, size) for name, size in zip(self.axis_names, self.mesh_shape))

def get_logical_mesh(self):
return self.device_ids.reshape(self.mesh_shape)


@requires_pjrt
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor:
"""
Annotates the tensor provided with XLA partition spec. Internally,
it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
Args:
t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_sepc.
mesh_shape (Tuple[Union[int, None]]): An int tuple describing the logical topology
of the device mesh; each element describes the number of devices in
the corresponding axis.

mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.

partition_spec (Tuple[int, None]): A tuple of device_mesh dimension indices or `None`.
Each entry specifies how the corresponding input rank is sharded (an index into mesh_shape) or replicated (None).
For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
>> input = torch.randn(8, 10)
>> mesh_shape = (4, 2)
>> assert np.prod(mesh_shape) == xm.xrt_world_size()
>> partition_spec = (0, None)
>> assert len(input.shape) == len(partition_spec)

Examples
------------------------------
mesh_shape = (4, 2)
input = torch.randn(8, 32).to(xm.xla_device())
num_devices = len(xm.get_xla_supported_devices())
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

# 4-way data parallel
input = xs.mark_sharding(input, mesh_shape, (0, None))
linear = nn.Linear(32, 10).to(xm.xla_device())
input = torch.randn(8, 32).to(xm.xla_device())
xs.mark_sharding(input, mesh, (0, None))

# 2-way model parallel
linear.weight = xs.mark_sharding(linear.weight, device_mesh, (None, 1))
output = linear(input)
# full replication
output = xs.mark_sharding(output, device_mesh, (None, None))
linear = nn.Linear(32, 10).to(xm.xla_device())
xs.mark_sharding(linear.weight, mesh, (None, 1))
"""
num_devices = len(xm.get_xla_supported_devices())
assert num_devices > 0, "This requires XLA supported device(s)."
assert np.prod(mesh_shape) == num_devices, \
f"{mesh_shape} is not mappable over {num_devices} devices."
assert all((d >= 0 and d < len(mesh_shape)) for d in partition_spec if d), \
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
assert all((d >= 0 and d < len(mesh.mesh_shape)) for d in partition_spec if d), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
# TODO(yeounoh) allow unspecified ranks (len(partition_spec) <= len(t.shape)),
# for replication. For now, all input rank sharding should be specified.
@@ -53,15 +114,12 @@ def mark_sharding(t: Union[torch.Tensor,
assert len(dims) == len(np.unique(dims)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

device_ids = np.array(range(num_devices))
tile_assignment = device_ids.reshape(mesh_shape).tolist()

tile_assignment = mesh.get_logical_mesh().tolist()
manual, replicated, partial = False, False, False
if all(d is None for d in partition_spec):
replicated = True
elif any(d is None for d in partition_spec):
partial = True

# TODO(yeounoh) support partially replicated sharding.
assert not partial, "Partial replication is currently not supported."

Expand Down