Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions test/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ def test_deep_copy(self):
torch_xla._XLAC._get_xla_sharding_spec(xt),
torch_xla._XLAC._get_xla_sharding_spec(xt2))

def test_mark_step_with_sharding(self):
xt = torch.ones(2, 2).to(xm.xla_device())
xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1))
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)
xm.mark_step() # mark_step should preserve the sharding
self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))

def test_inplace_add_with_sharding(self):
xt = torch.ones(2, 2).to(xm.xla_device())
xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1))
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)
xt.add_(1) # inplace update should preserve the sharding
self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))


class VirtualDeviceTest(XlaShardingTest):

Expand Down
19 changes: 12 additions & 7 deletions third_party/xla_client/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,16 +367,21 @@ PjRtComputationClient::ExecuteReplicated(

std::vector<std::vector<ComputationClient::DataPtr>> data_handles;
data_handles.reserve(results.size());
for (auto& result : results) {
for (int32_t i = 0; i < results.size(); ++i) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString() << " is not addressable.";

std::vector<ComputationClient::DataPtr> datas;
datas.reserve(result.size());
for (int32_t i = 0; i < result.size(); ++i) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result[i]);
datas.reserve(results[i].size());
for (int32_t j = 0; j < results[i].size(); ++j) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(results[i][j]);
XLA_CHECK(pjrt_device == buffer->device())
<< "Exepcted device: " << pjrt_device->DebugString()
<< " vs. actual device: " << buffer->device()->DebugString();

std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
devices[i], buffer->on_device_shape(),
std::move(buffer));

devices[i], buffer->on_device_shape(), std::move(buffer));
datas.push_back(data);
}
data_handles.push_back(datas);
Expand Down
24 changes: 14 additions & 10 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ void XLATensor::ClearShardingSpec() {
torch::lazy::Value ir_value = CurrentIrValue();
if (ir_value) {
if (ir_value.node != nullptr) {
dynamic_cast<XlaNode*>(GetIrValue().node.get())->ClearSharding();
dynamic_cast<XlaNode*>(ir_value.node.get())->ClearSharding();
}
}
}
Expand Down Expand Up @@ -723,9 +723,14 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
}

void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {
// Sharding annotation is not null, if xla_data() is sharded.
ShardingSpecPtr sharding = sharding_spec();
if (ir_value && sharding != nullptr) {
if (sharding != nullptr) {
// Sharded xla_data is accompanied by sharding annotation.
// Use unsynced ir_value or xla_data to hold the annotation.
// TODO(yeounoh): This does not propagate sharding to views.
if (!ir_value) {
ir_value = GetIrValue();
}
dynamic_cast<XlaNode*>(ir_value.node.get())
->SetSharding(sharding->sharding);
}
Expand Down Expand Up @@ -1543,23 +1548,25 @@ std::shared_ptr<XLATensor::Async> XLATensor::ScheduleSyncTensorsGraph(
std::vector<torch::lazy::BackendDataPtr> results;
// Execute replicated if the compiled computation is partitioned.
if (async->cached_computation->is_sharded) {
// TODO(yeounoh) use local devices and verify with the pod execution.
std::vector<std::string> devices =
xla::ComputationClient::Get()->GetAllDevices();
xla::ComputationClient::Get()->GetLocalDevices();
std::vector<std::vector<xla::ComputationClient::DataPtr>>
device_arguments = torch_xla::ShardingUtil::InputHandler(
UnwrapXlaData(async->parameters_data), devices);
xla::ComputationClient::ExecuteReplicatedOptions execute_options;

TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash) << " on all devices.";
<< torch::lazy::HashToString(hash)
<< " on devices: " << absl::StrJoin(devices, ",");
// TODO(jwtan): Remove the WrapXlaData when inherits LazyGraphExecutor.
results = WrapXlaData(xla::ComputationClient::Get()->ExecuteReplicated(
*async->cached_computation->computation->client_computation(),
device_arguments, devices,
execute_options)[0]); // TODO(yeounoh) assumes replicated outputs
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash)
<< " on all devices, done!";
<< " on devices: " << absl::StrJoin(devices, ",")
<< " done!";
} else {
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
Expand Down Expand Up @@ -1912,7 +1919,6 @@ std::shared_ptr<XLATensor::Async> XLATensor::SyncTensorsGraphInternal(
ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values,
tensor_data_vec);
PostOrderData po_data = RunPostOrder(ir_values, &coll);

coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
TF_VLOG(4) << "Parameter sequence graph hash "
Expand All @@ -1922,10 +1928,8 @@ std::shared_ptr<XLATensor::Async> XLATensor::SyncTensorsGraphInternal(
if (async != nullptr) {
return async;
}

CompilationResult compile_result =
Compile(*tensors, devices, coll, &po_data, ir_values);

XLA_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;

Expand Down