Support multihost SPMD execution #4573
@@ -175,26 +175,43 @@ xla::HloModuleProto ShardingUtil::SpmdPartitioningPass(
  return module.get()->ToProto();
}

// Builds a map from the device's global ordinal to its index in the `devices`
// array. This is used by `ShardTensor` and `InputHandler` to ensure the
// order of the output corresponds to the order of the `devices`, which can be
// arbitrarily set by the caller.
static std::unordered_map<int, int> build_index_map(
    const std::vector<std::string>& devices) {
  std::unordered_map<int, int> device_index;
  for (int i = 0; i < devices.size(); ++i) {
    int global_ordinal = ParseDeviceString(devices[i]).ordinal();
Review comment: The first global device gets the local index 0, so the order of the input devices list is important. Is this a correct understanding? Can we add some comments on this?

Reply: The first device in the list gets local index 0, but the order of the global ordinals within `devices` can be arbitrary.
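A small Python sketch of the resulting map may make the ordering behavior concrete; the device strings and ordinals below are hypothetical, chosen only to illustrate how `build_index_map` behaves:

# Hypothetical illustration: each global ordinal maps to the position of its
# device string in the caller-supplied `devices` list, in whatever order the
# caller chose; the ordinals themselves need not be sorted or contiguous.
devices = ["TPU:4", "TPU:6", "TPU:5"]
device_index = {int(d.split(":")[1]): i for i, d in enumerate(devices)}
assert device_index == {4: 0, 6: 1, 5: 2}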
    device_index[global_ordinal] = i;
  }
  return device_index;
}

std::vector<std::vector<xla::ComputationClient::DataPtr>>
ShardingUtil::InputHandler(
    std::vector<xla::ComputationClient::DataPtr> arguments,
    std::vector<std::string> devices) {
  std::vector<std::vector<xla::ComputationClient::DataPtr>> arguments_by_device(
      devices.size(),
      std::vector<xla::ComputationClient::DataPtr>(arguments.size()));
  auto device_index = build_index_map(devices);

  for (int64_t argument_i = 0; argument_i < arguments.size(); ++argument_i) {
    auto shards =
        xla::ComputationClient::Get()->GetDataShards(arguments[argument_i]);
    if (shards.size() > 1) {
      // Input is sharded across addressable devices
      for (auto shard : shards) {
        int64_t device_i = ParseDeviceString(shard->device()).ordinal();
        int global_ordinal = ParseDeviceString(shard->device()).ordinal();
        int device_i = device_index[global_ordinal];
        arguments_by_device[device_i][argument_i] = shard;
      }
    } else {
      // Input is replicated across addressable devices
      int64_t source_device_i =
          ParseDeviceString(shards[0]->device()).ordinal();
      int global_ordinal = ParseDeviceString(shards[0]->device()).ordinal();
      int source_device_i = device_index[global_ordinal];
      arguments_by_device[source_device_i][argument_i] = shards[0];
      for (int64_t device_i = 0; device_i < devices.size(); ++device_i) {
        if (device_i != source_device_i) {
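
Conceptually, `InputHandler` regroups a per-argument list of shards into a per-device list of arguments, using the same global-ordinal-to-index map. Below is a rough Python sketch of that layout with entirely made-up device ordinals and shard payloads; the real code operates on `ComputationClient::DataPtr` handles:

# Hypothetical sketch: shards arrive grouped by argument and are regrouped so
# that arguments_by_device[d] holds every argument's shard for devices[d].
devices = ["TPU:4", "TPU:5"]
device_index = {4: 0, 5: 1}

# shards_per_argument[a] is a list of (global_ordinal, shard) pairs.
shards_per_argument = [
    [(4, "arg0-shard-on-4"), (5, "arg0-shard-on-5")],  # sharded argument
    [(5, "arg1-replicated")],                          # replicated argument
]

arguments_by_device = [[None] * len(shards_per_argument) for _ in devices]
for a, shards in enumerate(shards_per_argument):
  if len(shards) > 1:
    for ordinal, shard in shards:
      arguments_by_device[device_index[ordinal]][a] = shard
  else:
    # Replicated input: reuse the single shard on every requested device.
    for d in range(len(devices)):
      arguments_by_device[d][a] = shards[0][1]

assert arguments_by_device == [
    ["arg0-shard-on-4", "arg1-replicated"],
    ["arg0-shard-on-5", "arg1-replicated"],
]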
|
|
@@ -214,19 +231,30 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
    const std::vector<std::string>& devices, bool padded) {
  TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")..."
               << std::endl;
  auto device_index = build_index_map(devices);
  std::vector<at::Tensor> shards(devices.size());
  if (sharding.type() == xla::OpSharding::REPLICATED) {
    std::fill_n(shards.begin(), shards.size(), tensor);
  } else if (sharding.type() == xla::OpSharding::OTHER) {
    XLA_CHECK_EQ(devices.size(), sharding.tile_assignment_devices().size())
        << "Invalid sharding tile_assignment_devices.size(): expected "
        << devices.size() << ", actual "
        << sharding.tile_assignment_devices().size();
    XLA_CHECK(sharding.tile_shape().dimensions_size() <= 2);
    XLA_CHECK(tensor.sizes().size() >= sharding.tile_shape().dimensions_size());

    auto tile_shape = sharding.tile_assignment_dimensions();

    // `partition_len[j]` is the size of dimension `j` in the resulting shard.
    std::vector<int64_t> partition_len;
    for (int j = 0; j < tile_shape.size(); j++) {
      partition_len.push_back(tensor.sizes()[j] / tile_shape[j] +
                              (tensor.sizes()[j] % tile_shape[j] != 0));
    }

    for (size_t i = 0; i < sharding.tile_assignment_devices().size(); i++) {
      int64_t core = sharding.tile_assignment_devices()[i];
      if (device_index.find(core) == device_index.end()) {
        // Skip any shards whose device is not part of the `devices` list.
        continue;
      }

      // Given the shard's row-major index `i`, we need to calculate shard's
      // coordinates (n_0, ..., n_d) in the tiling to generate the index slices.
      // Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the following
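As a concrete walk-through of this index-to-coordinates computation, here is a hypothetical example in plain Python; the tensor and tile sizes are made up, and the C++ code in the next hunk does the same arithmetic on `at::Tensor` slices:

# Hypothetical walk-through: an 8x10 tensor tiled by a (4, 2) assignment.
tensor_sizes = [8, 10]   # stands in for tensor.sizes()
tile_shape = [4, 2]      # stands in for sharding.tile_assignment_dimensions()

# Ceiling division, matching the partition_len computation above.
partition_len = [size // n + (size % n != 0)
                 for size, n in zip(tensor_sizes, tile_shape)]
assert partition_len == [2, 5]

i = 5                    # row-major shard index within the tile assignment
offset, slices = i, []
for j in reversed(range(len(tile_shape))):
  n_j = offset % tile_shape[j]
  slices.append((n_j * partition_len[j], (n_j + 1) * partition_len[j]))
  offset //= tile_shape[j]
slices.reverse()
assert slices == [(4, 6), (5, 10)]   # shard 5 covers rows [4, 6), cols [5, 10)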
|
|
@@ -239,32 +267,26 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
      std::vector<at::indexing::TensorIndex> indices;
      for (int j = tile_shape.size() - 1; j >= 0; j--) {
        int64_t n_j = offset % tile_shape[j];
        int64_t partition_len = tensor.sizes()[j] / tile_shape[j] +
                                (tensor.sizes()[j] % tile_shape[j] != 0);
        auto slice =
            at::indexing::Slice(n_j * partition_len, (n_j + 1) * partition_len);
        auto slice = at::indexing::Slice(n_j * partition_len[j],
                                         (n_j + 1) * partition_len[j]);
        indices.push_back(at::indexing::TensorIndex(slice));
        offset /= tile_shape[j];
      }
      std::reverse(indices.begin(), indices.end());
      at::Tensor shard =
          tensor.index(c10::ArrayRef<at::indexing::TensorIndex>(indices));

      int64_t core = sharding.tile_assignment_devices()[i];
      shards[core] = shard.contiguous(at::MemoryFormat::Contiguous);
      shards[device_index[core]] =
          shard.contiguous(at::MemoryFormat::Contiguous);
    }

    // Zero-pad to the right to ensure the sizes are even
    if (shards.size() > 0 && padded) {
      for (size_t i = 1; i < shards.size(); ++i) {
      for (size_t i = 0; i < shards.size(); ++i) {
        std::vector<long> pads;
        for (size_t j = 0; j < shards[i].sizes().size(); ++j) {
          XLA_CHECK_GE(
              shards[sharding.tile_assignment_devices()[0]].sizes().at(j),
              shards[i].sizes().at(j));
          pads.push_back(
              shards[sharding.tile_assignment_devices()[0]].sizes().at(j) -
              shards[i].sizes().at(j));
        for (size_t j = 0; j < partition_len.size(); ++j) {
          XLA_CHECK_GE(partition_len[j], shards[i].sizes().at(j));
          pads.push_back(partition_len[j] - shards[i].sizes().at(j));
          pads.push_back(0);  // no padding on lhs
        }
        // Padding starts from the last dimension
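A hypothetical numeric sketch of this right-side zero padding; the sizes are made up, and `torch.nn.functional.pad` is used here only to mirror the pad layout (the actual C++ continuation is in the next hunk):

# Hypothetical sketch: pad an edge shard of shape [2, 3] up to partition_len
# [2, 5]. Each dimension contributes a (rhs, lhs=0) pair; reversing the list
# puts the last dimension first, which is the order F.pad expects.
import torch
import torch.nn.functional as F

partition_len = [2, 5]
shard = torch.ones(2, 3)   # an edge shard that is short in dim 1

pads = []
for j in range(len(partition_len)):
  pads.append(partition_len[j] - shard.size(j))  # zero-pad on the right
  pads.append(0)                                 # no padding on the left
pads.reverse()
padded = F.pad(shard, pads)
assert list(padded.shape) == [2, 5]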
|
|
@@ -275,7 +297,7 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
    }
  } else if ((sharding.type() == xla::OpSharding::MANUAL) ||
             (sharding.type() == xla::OpSharding::TUPLE)) {
    TF_LOG(ERROR) << "Unsupported OpSharidng type " << sharding.type();
    TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
  }
  return shards;
}
|
|
@@ -2,6 +2,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt as pjrt
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.pjrt import requires_pjrt
|
|
@@ -88,7 +89,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
  Examples
  —------------------------------
  mesh_shape = (4, 2)
  num_devices = len(xm.get_xla_supported_devices())
  num_devices = pjrt.global_device_count()
Review comment: Great :)
  device_ids = np.array(range(num_devices))
  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
|
|
@@ -100,7 +101,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
  linear = nn.Linear(32, 10).to(xm.xla_device())
  xs.mark_sharding(linear.weight, mesh, (None, 1))
  """
  num_devices = len(xm.get_xla_supported_devices())
  num_devices = pjrt.global_device_count()
  assert num_devices > 0, "This requires XLA supported device(s)."
  assert mesh.size() == num_devices, \
    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
Review comment: Does GetLocalDevices() return local devices with global ordinals? If so, let's leave a comment.