37 changes: 34 additions & 3 deletions test/cpp/test_xla_sharding.cpp
@@ -117,6 +117,39 @@ TEST_F(XLAShardingTest, ShardTensor) {
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
}

TEST_F(XLAShardingTest, ShardTensorMultiHost) {
std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};

// 2D tiled: the first dim is halved and the last is replicated.
at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
xla::Array2D<int64_t> mesh({
{4, 5, 0, 1},
{6, 7, 2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();

// For devices at the start of the mesh, all shards should have the same
// unpadded shape.
auto shards =
ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
EXPECT_EQ(shards.size(), 4);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({4, 2, 4}));

// When this host's devices are at the end of the mesh, the last shard should
// be smaller in dim=1, because size 7 is not evenly divisible by 4.
mesh = xla::Array2D<int64_t>({
{0, 1, 4, 5},
{2, 3, 6, 7},
});
sharding = xla::HloSharding::Tile(mesh).ToProto();
shards =
ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
EXPECT_EQ(shards.size(), 4);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({4, 1, 4}));
}
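// Illustrative arithmetic behind the expected shapes above: the {8, 7, 4}
// tensor is tiled by a 2x4 mesh over dims 0 and 1 (dim 2 is replicated), so a
// full shard is {8/2, ceil(7/4), 4} = {4, 2, 4}. Without padding, devices in
// the last mesh column only get the dim-1 remainder, 7 - 3*2 = 1, hence
// {4, 1, 4}.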

TEST_F(XLAShardingTest, EqualShardingSpecs) {
XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
{0, 1, 2, 3},
@@ -184,15 +217,13 @@ TEST_F(XLAShardingTest, InputHandler) {
std::vector<at::Tensor> tensors(2);
std::fill_n(tensors.begin(), tensors.size(),
at::ones({8, 8}, at::TensorOptions(at::kFloat)));
std::vector<std::string> devices(2);
std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString());
std::vector<std::string> devices = {"TPU:0", "TPU:1"};
std::vector<XLATensor::ShardingSpecPtr> shardings = {
nullptr, std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto())};
std::vector<torch::lazy::BackendDataPtr> tensors_data =
CreateTensorsData(tensors, shardings, devices);

devices = xla::ComputationClient::Get()->GetLocalDevices();
std::vector<xla::ComputationClient::DataPtr> arguments =
UnwrapXlaData(tensors_data);
auto arguments_by_device = ShardingUtil::InputHandler(arguments, devices);
22 changes: 9 additions & 13 deletions torch_xla/csrc/tensor_util.cpp
@@ -931,30 +931,26 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
std::vector<xla::ComputationClient::DataPtr> new_handles; // out
if (shardings[i] != nullptr) {
xla::OpSharding sharding = shardings[i]->sharding;
// TODO(yeounoh) PJRT runs a process per host for SPMD and without cross
// host communications. This means that we may need to manually shard
// across global devices for multi-host training.
// GetLocalDevices returns the list of local devices specified by their
// global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).
std::vector<std::string> local_devices =
Contributor:
Does GetLocalDevices() return local devices with global ordinals? If so, let's leave a comment.

xla::ComputationClient::Get()->GetAllDevices();
xla::ComputationClient::Get()->GetLocalDevices();
// Shards the input tensors with padding, to split evenly.
// The execution requires consistent shard sizes, and the zero-padded
// values should be ignored.
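// Illustrative example (values assumed): if a dimension of size 7 is tiled 4
// ways, the trailing shard of size 1 is zero-padded on the right up to the
// common shard size of ceil(7/4) = 2, so every device receives a shard of the
// same shape.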
std::vector<at::Tensor> shards = ShardingUtil::ShardTensor(
std::vector<at::Tensor> local_shards = ShardingUtil::ShardTensor(
tensors[i], sharding, local_devices, /*padded=*/true);

for (int64_t j = 0; j < shards.size(); ++j) {
int64_t ordinal = (sharding.type() == xla::OpSharding::OTHER)
? sharding.tile_assignment_devices()[j]
: j;
auto shard_device = ParseDeviceString(local_devices[ordinal]);
for (int64_t j = 0; j < local_shards.size(); ++j) {
auto shard_device = ParseDeviceString(local_devices[j]);
auto shard_shape =
CreateComputationShapeFromTensor(shards[j], &shard_device);
CreateComputationShapeFromTensor(local_shards[j], &shard_device);
auto populate_fn =
[&, j, shard_device](
const xla::ComputationClient::TensorSource& source_tensor,
void* dest_buffer, size_t dest_buffer_size) {
PopulateTensorBuffer(shards[j], source_tensor.shape, dest_buffer,
dest_buffer_size, shard_device);
PopulateTensorBuffer(local_shards[j], source_tensor.shape,
dest_buffer, dest_buffer_size, shard_device);
};
source_tensors.emplace_back(std::move(shard_shape),
shard_device.toString(),
1 change: 0 additions & 1 deletion torch_xla/csrc/xla_graph_executor.cpp
@@ -905,7 +905,6 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
std::vector<torch::lazy::BackendDataPtr> results;
// Execute replicated if the compiled computation is partitioned.
if (async->cached_computation->is_sharded) {
// TODO(yeounoh) use local devices and verify with the pod execution.
std::vector<std::string> devices =
xla::ComputationClient::Get()->GetLocalDevices();
std::vector<std::vector<xla::ComputationClient::DataPtr>>
66 changes: 44 additions & 22 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -175,26 +175,43 @@ xla::HloModuleProto ShardingUtil::SpmdPartitioningPass(
return module.get()->ToProto();
}

// Builds a map from the device's global ordinal to its index in the `devices`
// array. This is used by `ShardTensor` and `InputHandler` to ensure the
// order of the output corresponds to the order of the `devices`, which can be
// arbitrarily set by the caller.
static std::unordered_map<int, int> build_index_map(
const std::vector<std::string>& devices) {
std::unordered_map<int, int> device_index;
for (int i = 0; i < devices.size(); ++i) {
int global_ordinal = ParseDeviceString(devices[i]).ordinal();
Contributor:
The first global device gets the local index 0, so the order of the input devices list is important. Is this a correct understanding? Can we add some comments on this?

Collaborator Author:
The first device in the list gets local index 0, but the order of the global ordinals within devices doesn't matter. I'll add some more documentation around this.

device_index[global_ordinal] = i;
}
return device_index;
}
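// Standalone sketch of the mapping discussed in the thread above, assuming the
// "TPU:<global ordinal>" device naming used in the tests; it parses the
// ordinal with plain string ops instead of ParseDeviceString. The local index
// is the position in `devices`, regardless of the ordinal values themselves.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<std::string> devices = {"TPU:6", "TPU:4", "TPU:7", "TPU:5"};
  std::unordered_map<int, int> device_index;
  for (int i = 0; i < devices.size(); ++i) {
    int global_ordinal = std::stoi(devices[i].substr(devices[i].find(':') + 1));
    device_index[global_ordinal] = i;  // keyed by ordinal, valued by position
  }
  for (const auto& [ordinal, index] : device_index) {
    std::cout << "global TPU:" << ordinal << " -> local index " << index << "\n";
  }
  // Prints (in some order): TPU:6 -> 0, TPU:4 -> 1, TPU:7 -> 2, TPU:5 -> 3,
  // i.e. "TPU:6" gets local index 0 simply because it appears first in the list.
}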

std::vector<std::vector<xla::ComputationClient::DataPtr>>
ShardingUtil::InputHandler(
std::vector<xla::ComputationClient::DataPtr> arguments,
std::vector<std::string> devices) {
std::vector<std::vector<xla::ComputationClient::DataPtr>> arguments_by_device(
devices.size(),
std::vector<xla::ComputationClient::DataPtr>(arguments.size()));
auto device_index = build_index_map(devices);

for (int64_t argument_i = 0; argument_i < arguments.size(); ++argument_i) {
auto shards =
xla::ComputationClient::Get()->GetDataShards(arguments[argument_i]);
if (shards.size() > 1) {
// Input is sharded across addressable devices
for (auto shard : shards) {
int64_t device_i = ParseDeviceString(shard->device()).ordinal();
int global_ordinal = ParseDeviceString(shard->device()).ordinal();
int device_i = device_index[global_ordinal];
arguments_by_device[device_i][argument_i] = shard;
}
} else {
// Input is replicated across addressable devices
int64_t source_device_i =
ParseDeviceString(shards[0]->device()).ordinal();
int global_ordinal = ParseDeviceString(shards[0]->device()).ordinal();
int source_device_i = device_index[global_ordinal];
arguments_by_device[source_device_i][argument_i] = shards[0];
for (int64_t device_i = 0; device_i < devices.size(); ++device_i) {
if (device_i != source_device_i) {
@@ -214,19 +231,30 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
const std::vector<std::string>& devices, bool padded) {
TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")..."
<< std::endl;
auto device_index = build_index_map(devices);
std::vector<at::Tensor> shards(devices.size());
if (sharding.type() == xla::OpSharding::REPLICATED) {
std::fill_n(shards.begin(), shards.size(), tensor);
} else if (sharding.type() == xla::OpSharding::OTHER) {
XLA_CHECK_EQ(devices.size(), sharding.tile_assignment_devices().size())
<< "Invalid sharding tile_assignment_devices.size(): expected "
<< devices.size() << ", actual "
<< sharding.tile_assignment_devices().size();
XLA_CHECK(sharding.tile_shape().dimensions_size() <= 2);
XLA_CHECK(tensor.sizes().size() >= sharding.tile_shape().dimensions_size());

auto tile_shape = sharding.tile_assignment_dimensions();

// `partition_len[j]` is the size of dimension `j` in the resulting shard.
std::vector<int64_t> partition_len;
for (int j = 0; j < tile_shape.size(); j++) {
partition_len.push_back(tensor.sizes()[j] / tile_shape[j] +
(tensor.sizes()[j] % tile_shape[j] != 0));
}

for (size_t i = 0; i < sharding.tile_assignment_devices().size(); i++) {
int64_t core = sharding.tile_assignment_devices()[i];
if (device_index.find(core) == device_index.end()) {
// Skip any shards whose device is not part of the `devices` list.
continue;
}

// Given the shard's row-major index `i`, we need to calculate shard's
// coordinates (n_0, ..., n_d) in the tiling to generate the index slices.
// Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the following
@@ -239,32 +267,26 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
std::vector<at::indexing::TensorIndex> indices;
for (int j = tile_shape.size() - 1; j >= 0; j--) {
int64_t n_j = offset % tile_shape[j];
int64_t partition_len = tensor.sizes()[j] / tile_shape[j] +
(tensor.sizes()[j] % tile_shape[j] != 0);
auto slice =
at::indexing::Slice(n_j * partition_len, (n_j + 1) * partition_len);
auto slice = at::indexing::Slice(n_j * partition_len[j],
(n_j + 1) * partition_len[j]);
indices.push_back(at::indexing::TensorIndex(slice));
offset /= tile_shape[j];
}
std::reverse(indices.begin(), indices.end());
at::Tensor shard =
tensor.index(c10::ArrayRef<at::indexing::TensorIndex>(indices));

int64_t core = sharding.tile_assignment_devices()[i];
shards[core] = shard.contiguous(at::MemoryFormat::Contiguous);
shards[device_index[core]] =
shard.contiguous(at::MemoryFormat::Contiguous);
}

// Zero-pad to the right to ensure the sizes are even
if (shards.size() > 0 && padded) {
for (size_t i = 1; i < shards.size(); ++i) {
for (size_t i = 0; i < shards.size(); ++i) {
std::vector<long> pads;
for (size_t j = 0; j < shards[i].sizes().size(); ++j) {
XLA_CHECK_GE(
shards[sharding.tile_assignment_devices()[0]].sizes().at(j),
shards[i].sizes().at(j));
pads.push_back(
shards[sharding.tile_assignment_devices()[0]].sizes().at(j) -
shards[i].sizes().at(j));
for (size_t j = 0; j < partition_len.size(); ++j) {
XLA_CHECK_GE(partition_len[j], shards[i].sizes().at(j));
pads.push_back(partition_len[j] - shards[i].sizes().at(j));
pads.push_back(0); // no padding on lhs
}
// Padding starts from the last dimension
@@ -275,7 +297,7 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
}
} else if ((sharding.type() == xla::OpSharding::MANUAL) ||
(sharding.type() == xla::OpSharding::TUPLE)) {
TF_LOG(ERROR) << "Unsupported OpSharidng type " << sharding.type();
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
}
return shards;
}
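// Standalone sketch (values assumed) of the row-major index -> tile coordinate
// arithmetic used above: for a 2x4 tile assignment over an 8x7 tensor, the
// shard at flat index i = 5 has coordinates (n_0, n_1) = (1, 1) and therefore
// slices [4, 8) of dim 0 and [2, 4) of dim 1.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> tile_shape = {2, 4};  // tile_assignment_dimensions
  std::vector<int64_t> sizes = {8, 7};       // tensor dimensions being tiled
  int64_t i = 5;                             // shard's row-major index

  std::vector<int64_t> coords(tile_shape.size());
  int64_t offset = i;
  for (int j = tile_shape.size() - 1; j >= 0; j--) {
    coords[j] = offset % tile_shape[j];
    offset /= tile_shape[j];
  }
  for (int j = 0; j < tile_shape.size(); j++) {
    // Same ceil-division as partition_len in ShardTensor.
    int64_t len = sizes[j] / tile_shape[j] + (sizes[j] % tile_shape[j] != 0);
    std::cout << "dim " << j << ": n_" << j << " = " << coords[j]
              << ", slice [" << coords[j] * len << ", "
              << (coords[j] + 1) * len << ")\n";
  }
}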
18 changes: 11 additions & 7 deletions torch_xla/csrc/xla_sharding_util.h
@@ -38,19 +38,23 @@ class ShardingUtil {

// This reshuffles arguments (sharded or replicated) on the devices. The
// size of the arguments vector must match that of the sharding_specs.
// The returned arguments will be in 1:1 correspondence with the `devices`
// vector, so the `i`th result will belong on the `i`th device.
// TODO(yeounoh) avoiding pre-loading of the unpartitioned input arguments
// might improve the performance and save the bandwidth.
static std::vector<std::vector<xla::ComputationClient::DataPtr>> InputHandler(
std::vector<xla::ComputationClient::DataPtr> arguments,
std::vector<std::string> devices);
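// Illustrative shape of the InputHandler result (sizes assumed): with
// devices = {"TPU:2", "TPU:3"} and 3 arguments, the result is a 2x3 table
// where result[0][k] holds argument k's shard (or replica) for "TPU:2".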

// Shard a tensor and returns the sharded tensors based on the `sharding`
// spec. REPLICATED sharding should result in shards identical to the input;
// OTHERS (tiled) sharding result in shards where each data dimension is
// sharded across devices along the same dimension in the `tile_assignment`;
// the returned tensor shards vector is indexed by the device IDs. There is no
// data duplication. Shards are not padded in case the input tensor is not
// evenly partitionable, unless `padded` is set.
// Shards a tensor and returns the sharded tensors that belong on `devices`,
// based on the `sharding` spec. REPLICATED sharding should result in shards
// identical to the input; OTHER (tiled) sharding results in shards where each
// data dimension is sharded across devices along the same dimension in the
// `tile_assignment`. There is no data duplication. Shards are not padded in
// case the input tensor is not evenly partitionable, unless `padded` is set.
// The returned tensors will be in 1:1 correspondence with the `devices`
// vector, so the `i`th result will belong on the `i`th device.
static std::vector<at::Tensor> ShardTensor(
const at::Tensor& tensor, const xla::OpSharding sharding,
const std::vector<std::string>& devices, bool padded = true);
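// Illustrative expectation (sharding and device names assumed), mirroring the
// ShardTensorMultiHost test:
//   auto shards = ShardingUtil::ShardTensor(
//       tensor, tiled_2x4_sharding, {"TPU:4", "TPU:5", "TPU:6", "TPU:7"},
//       /*padded=*/false);
//   // shards.size() == 4 and shards[i] is the tile assigned to devices[i];
//   // tiles assigned to devices outside the list are simply skipped.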
5 changes: 3 additions & 2 deletions torch_xla/experimental/xla_sharding.py
@@ -2,6 +2,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt as pjrt
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.pjrt import requires_pjrt

@@ -88,7 +89,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
Examples
—------------------------------
mesh_shape = (4, 2)
num_devices = len(xm.get_xla_supported_devices())
num_devices = pjrt.global_device_count()
Contributor:
Great :)

device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

@@ -100,7 +101,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
linear = nn.Linear(32, 10).to(xm.xla_device())
xs.mark_sharding(linear.weight, mesh, (None, 1))
"""
num_devices = len(xm.get_xla_supported_devices())
num_devices = pjrt.global_device_count()
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."