From 2417643045425d268d0dccaa5b6d2c920459337b Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Fri, 12 May 2023 00:00:41 +0000
Subject: [PATCH 1/5] initial code change

---
 torch_xla/csrc/aten_xla_bridge.cpp    |  5 +++
 torch_xla/csrc/xla_graph_executor.cpp | 48 ++++++++++++++++++++++-----
 2 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp
index 892e5bbf8e5..3c7a5dc6fdc 100644
--- a/torch_xla/csrc/aten_xla_bridge.cpp
+++ b/torch_xla/csrc/aten_xla_bridge.cpp
@@ -305,6 +305,11 @@ torch::lazy::BackendDevice AtenDeviceToXlaDevice(const c10::Device& device) {
 }
 
 c10::Device XlaDeviceToAtenDevice(const torch::lazy::BackendDevice& device) {
+  // TODO(yeounoh) until we expose SPMD virtual device to the frontend, this
+  // will just be `XLA:0`.
+  if (device.type() == (int8_t)XlaDeviceType::SPMD) {
+    return c10::Device(at::kXLA, (size_t)0);
+  }
   return c10::Device(at::kXLA,
                      AtenXlaDeviceMapper::Get()->GetDeviceOrdinal(device));
 }
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 9d7cab34aed..7ceb4712285 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -574,6 +574,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
   MaybeDumpGraph("dynamo", hash);
   auto cachedComputation =
       XLAGraphExecutor::Get()->GetComputationCache()->Get(hash);
+  TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash)
+             << ") is_sharded=" << cachedComputation->is_sharded << std::endl;
+
   // TODO implement a fallback mechanism, or make sure those entries
   // never get kicked out
   XLA_CHECK(cachedComputation)
@@ -592,6 +595,8 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
             device.toString(), std::move(shape)));
     placeholders.push_back(handle);
   }
+  // TODO(yeounoh) supply proper sharding specs for sharded results.
+  std::vector<XLATensor::ShardingSpecPtr> sharding_specs(placeholders.size());
 
   SyncTensorCollection coll;
   coll.device = device;
@@ -608,6 +613,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
     if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) {
       dataptr = xla_tensor_ptr->GetXlaData();
     } else {
+      XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD)
+          << "SPMD device data should already be on the XLA backend "
+             "(XLATensor).";
       dataptr = torch_xla::TensorToXlaData(ivalue.toTensor(), device);
     }
 
@@ -619,17 +627,41 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
   std::shared_ptr<XLAGraphExecutor::Async> async = std::make_shared<Async>(
       &coll, std::move(arguments), placeholders, std::move(cachedComputation));
 
-  auto syncfn = [async, hash]() {
+  auto syncfn = [async, hash, sharding_specs]() {
     TF_VLOG(3) << "Executing Dynamo IR graph hash "
                << torch::lazy::HashToString(hash) << " on device "
               << async->device << " ...";
-    std::vector<torch::lazy::BackendDataPtr> results =
-        torch::lazy::getBackend()->ExecuteComputation(
-            async->cached_computation->computation, async->parameters_data,
-            async->device);
-    TF_VLOG(3) << "Executing Dynamo IR graph hash "
-               << torch::lazy::HashToString(hash) << " on device "
-               << async->device << " done!";
+
+    std::vector<torch::lazy::BackendDataPtr> results;
+    if (async->cached_computation->is_sharded) {
+      std::vector<std::string> devices =
+          xla::ComputationClient::Get()->GetLocalDevices();
+      std::vector<std::vector<xla::ComputationClient::DataPtr>>
+          device_arguments = ShardingUtil::InputHandler(
+              UnwrapXlaData(async->parameters_data), devices);
+      xla::ComputationClient::ExecuteReplicatedOptions execute_options;
+      // OutputHandler creates sharded data for sharded
+      // tensor results. Both sharded and unsharded results should be
+      // "Assign"ed to the corresponding data placeholders.
+      std::vector<xla::ComputationClient::DataPtr> outputs =
+          ShardingUtil::OutputHandler(
+              xla::ComputationClient::Get()->ExecuteReplicated(
+                  *async->cached_computation->computation->client_computation(),
+                  device_arguments, devices, execute_options),
+              sharding_specs);
+      results = WrapXlaData(outputs);
+      TF_VLOG(3) << "Executing Dynamo IR graph hash "
+                 << torch::lazy::HashToString(hash) << " on devices "
+                 << absl::StrJoin(devices, ",") << " done!";
+    } else {
+      results = torch::lazy::getBackend()->ExecuteComputation(
+          async->cached_computation->computation, async->parameters_data,
+          async->device);
+      TF_VLOG(3) << "Executing Dynamo IR graph hash "
+                 << torch::lazy::HashToString(hash) << " on device "
+                 << async->device << " done!";
+    }
+
     // Updating placeholder with actual output handle.
     for (size_t i = 0; i < results.size(); ++i) {
       XLA_CHECK(async->tensors_data[i] != nullptr);

From 81850656076d4f5a95f35041a0b3849f7b57da22 Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Fri, 12 May 2023 00:57:21 +0000
Subject: [PATCH 2/5] Add simple test, which currently failed

---
 test/spmd/test_dynamo_spmd.py | 56 +++++++++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)
 create mode 100644 test/spmd/test_dynamo_spmd.py

diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
new file mode 100644
index 00000000000..8023875fd6b
--- /dev/null
+++ b/test/spmd/test_dynamo_spmd.py
@@ -0,0 +1,56 @@
+import os
+import sys
+
+import torch
+import torch.nn as nn
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.experimental.xla_sharding as xs
+import torch_xla.debug.metrics as met
+import unittest
+
+import test_xla_sharding_base
+
+
+class SimpleLinear(nn.Module):
+
+  def __init__(self):
+    super(SimpleLinear, self).__init__()
+    self.fc1 = nn.Linear(128, 128)
+    self.relu = nn.ReLU()
+    self.fc2 = nn.Linear(128, 1)
+    # Add an additional 1x1 layer at the end to ensure the final layer
+    # is not sharded.
+    self.fc3 = nn.Linear(1, 1)
+
+  def forward(self, x):
+    y = self.relu(self.fc1(x))
+    z = self.fc2(y)
+    return self.fc3(z)
+
+
+class DynamoSpmdInferenceTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    os.environ["XLA_USE_SPMD"] = "1"
+    super().setUpClass()
+
+  def test_dynamo_spmd_basic(self):
+    device = xm.xla_device()
+    linear = SimpleLinear().to(device)
+    linear.eval()
+    xla_x = torch.randn(1, 128, device=device)
+    xs.mark_sharding(linear.fc2.weight, self._get_mesh((1, self.n_devices)),
+                     (1, 0))
+    xla_res = linear(xla_x)
+    xm.mark_step()
+
+    # TODO: this currently fails with `Check failed: handle->HasValue()`
+    dynamo_linear = torch.compile(linear, backend="torchxla_trace_once")
+    dynamo_res = dynamo_linear(xla_x)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
\ No newline at end of file

From 89113ca07505fa1ae8559a777826e36259917488 Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Tue, 16 May 2023 01:44:33 +0000
Subject: [PATCH 3/5] Add try catch to dynamo, use SPMD device in dynamo,
 create sharded placeholder if SPMD is enabled

---
 test/spmd/test_dynamo_spmd.py           |  2 +-
 torch_xla/csrc/init_python_bindings.cpp |  5 +-
 torch_xla/csrc/xla_graph_executor.cpp   | 97 ++++++++++++++++---------
 3 files changed, 67 insertions(+), 37 deletions(-)

diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
index 8023875fd6b..4049f81b901 100644
--- a/test/spmd/test_dynamo_spmd.py
+++ b/test/spmd/test_dynamo_spmd.py
@@ -46,9 +46,9 @@ def test_dynamo_spmd_basic(self):
     xla_res = linear(xla_x)
     xm.mark_step()
 
-    # TODO: this currently fails with `Check failed: handle->HasValue()`
     dynamo_linear = torch.compile(linear, backend="torchxla_trace_once")
     dynamo_res = dynamo_linear(xla_x)
+    torch.allclose(xla_res.cpu(), dynamo_res.cpu())
 
 
 if __name__ == '__main__':
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 47709a83b60..e1bebb1cb19 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1728,7 +1728,10 @@ void InitXlaModuleBindings(py::module m) {
             -> std::vector<at::Tensor> {
           XLA_CHECK(hash_str.size() == sizeof(torch::lazy::hash_t));
           torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str());
-          torch::lazy::BackendDevice device = torch_xla::GetCurrentDevice();
+          // Device will be Virtual device if SPMD is enabled.
+          torch::lazy::BackendDevice device =
+              ShardingUtil::UseVirtualDevice() ? ParseDeviceString("SPMD:0")
+                                               : torch_xla::GetCurrentDevice();
           auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier(
               hash, graph_inputs, device);
           std::vector<at::Tensor> retlist;
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 7ceb4712285..7badf773ae6 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -593,8 +593,18 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
     torch::lazy::BackendDataPtr handle =
         WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
             device.toString(), std::move(shape)));
+    // if SPMD is enabled, we assume all output will be replicated
+    if (ShardingUtil::UseVirtualDevice()) {
+      XLATensor::ShardingSpecPtr sharding =
+          std::make_shared<XLATensor::ShardingSpec>(
+              xla::HloSharding::Replicate().ToProto(), shape);
+      handle = WrapXlaData(xla::ComputationClient::Get()->WrapDataShards(
+          {UnwrapXlaData(handle)}, GetVirtualDevice().toString(),
+          sharding->shape.value(), sharding->sharding));
+    }
     placeholders.push_back(handle);
   }
+
   // TODO(yeounoh) supply proper sharding specs for sharded results.
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs(placeholders.size());
 
@@ -628,44 +638,61 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
       &coll, std::move(arguments), placeholders, std::move(cachedComputation));
 
   auto syncfn = [async, hash, sharding_specs]() {
-    TF_VLOG(3) << "Executing Dynamo IR graph hash "
-               << torch::lazy::HashToString(hash) << " on device "
-               << async->device << " ...";
-
-    std::vector<torch::lazy::BackendDataPtr> results;
-    if (async->cached_computation->is_sharded) {
-      std::vector<std::string> devices =
-          xla::ComputationClient::Get()->GetLocalDevices();
-      std::vector<std::vector<xla::ComputationClient::DataPtr>>
-          device_arguments = ShardingUtil::InputHandler(
-              UnwrapXlaData(async->parameters_data), devices);
-      xla::ComputationClient::ExecuteReplicatedOptions execute_options;
-      // OutputHandler creates sharded data for sharded
-      // tensor results. Both sharded and unsharded results should be
-      // "Assign"ed to the corresponding data placeholders.
-      std::vector<xla::ComputationClient::DataPtr> outputs =
-          ShardingUtil::OutputHandler(
-              xla::ComputationClient::Get()->ExecuteReplicated(
-                  *async->cached_computation->computation->client_computation(),
-                  device_arguments, devices, execute_options),
-              sharding_specs);
-      results = WrapXlaData(outputs);
-      TF_VLOG(3) << "Executing Dynamo IR graph hash "
-                 << torch::lazy::HashToString(hash) << " on devices "
-                 << absl::StrJoin(devices, ",") << " done!";
-    } else {
-      results = torch::lazy::getBackend()->ExecuteComputation(
-          async->cached_computation->computation, async->parameters_data,
-          async->device);
+    try {
       TF_VLOG(3) << "Executing Dynamo IR graph hash "
                  << torch::lazy::HashToString(hash) << " on device "
-                 << async->device << " done!";
-    }
+                 << async->device << " ...";
 
-    // Updating placeholder with actual output handle.
-    for (size_t i = 0; i < results.size(); ++i) {
-      XLA_CHECK(async->tensors_data[i] != nullptr);
-      async->tensors_data[i]->Assign(*results[i]);
+      std::vector<torch::lazy::BackendDataPtr> results;
+      if (async->cached_computation->is_sharded) {
+        std::vector<std::string> devices =
+            xla::ComputationClient::Get()->GetLocalDevices();
+        std::vector<std::vector<xla::ComputationClient::DataPtr>>
+            device_arguments = ShardingUtil::InputHandler(
+                UnwrapXlaData(async->parameters_data), devices);
+        xla::ComputationClient::ExecuteReplicatedOptions execute_options;
+        // OutputHandler creates sharded data for sharded
+        // tensor results. Both sharded and unsharded results should be
+        // "Assign"ed to the corresponding data placeholders.
+        std::vector<xla::ComputationClient::DataPtr> outputs =
+            ShardingUtil::OutputHandler(
+                xla::ComputationClient::Get()->ExecuteReplicated(
+                    *async->cached_computation->computation
+                         ->client_computation(),
+                    device_arguments, devices, execute_options),
+                sharding_specs);
+        results = WrapXlaData(outputs);
+        TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
+                   << torch::lazy::HashToString(hash) << " on devices "
+                   << absl::StrJoin(devices, ",") << " done!";
+      } else {
+        results = torch::lazy::getBackend()->ExecuteComputation(
+            async->cached_computation->computation, async->parameters_data,
+            async->device);
+        TF_VLOG(3) << "Executing Dynamo IR graph hash "
+                   << torch::lazy::HashToString(hash) << " on device "
+                   << async->device << " done!";
+      }
+
+      // Updating placeholder with actual output handle.
+      for (size_t i = 0; i < results.size(); ++i) {
+        XLA_CHECK(async->tensors_data[i] != nullptr);
+        async->tensors_data[i]->Assign(*results[i]);
+      }
+    } catch (...) {
+      // There are two paths of discovery of an exception happening on an
+      // asynchronous task. One happens if the creator of the asynchronous task
+      // explicitly waits for completion, in which case the exception will be
+      // thrown from the Wait() API. Re-throwing the exception below makes sure
+      // this will be captured by the completer function created below, and
+      // surfaced by the Wait() API. But we also need to surface the exception
+      // even in case the caller does not wait, and that is accomplished by
+      // setting the unlockers status. In that case the exception will be
+      // surfaced when the user tries to acquire the device locks the next time.
+      for (auto& unlocker : async->unlocker) {
+        unlocker.SetStatus(std::current_exception());
+      }
+      throw;
     }
   };
 
From e47acfa892d088e4c8120e17b871d7cb8a288921 Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Wed, 17 May 2023 00:02:06 +0000
Subject: [PATCH 4/5] fix review comments

---
 test/run_tests.sh                     | 1 +
 test/spmd/test_dynamo_spmd.py         | 4 +++-
 torch_xla/csrc/xla_graph_executor.cpp | 6 +++---
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/test/run_tests.sh b/test/run_tests.sh
index 1a24bba4b92..36e2dd0c3ab 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -188,6 +188,7 @@ function run_xla_op_tests {
   run_test "$CDIR/pjrt/test_mesh_service.py"
   run_test "$CDIR/spmd/test_xla_sharding.py"
   run_test "$CDIR/spmd/test_xla_virtual_device.py"
+  run_test "$CDIR/spmd/test_dynamo_spmd.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
index 4049f81b901..2be24e1f78a 100644
--- a/test/spmd/test_dynamo_spmd.py
+++ b/test/spmd/test_dynamo_spmd.py
@@ -49,8 +49,10 @@ def test_dynamo_spmd_basic(self):
     dynamo_linear = torch.compile(linear, backend="torchxla_trace_once")
     dynamo_res = dynamo_linear(xla_x)
     torch.allclose(xla_res.cpu(), dynamo_res.cpu())
+    # TODO(JackCaoG): add counter checks after ExecuteReplicated also creates
+    # an ExecuteMetric.
 
 
 if __name__ == '__main__':
   test = unittest.main()
-  sys.exit(0 if test.result.wasSuccessful() else 1)
\ No newline at end of file
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 7badf773ae6..b420fc5914b 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -605,9 +605,6 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
     placeholders.push_back(handle);
   }
 
-  // TODO(yeounoh) supply proper sharding specs for sharded results.
-  std::vector<XLATensor::ShardingSpecPtr> sharding_specs(placeholders.size());
-
   SyncTensorCollection coll;
   coll.device = device;
   coll.unlocker = DeviceLockerArena::Get()->LockDevices({device});
@@ -637,6 +634,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
   std::shared_ptr<XLAGraphExecutor::Async> async = std::make_shared<Async>(
       &coll, std::move(arguments), placeholders, std::move(cachedComputation));
 
+  // TODO(yeounoh) supply proper sharding specs for sharded results.
+  std::vector<XLATensor::ShardingSpecPtr> sharding_specs(placeholders.size());
+
   auto syncfn = [async, hash, sharding_specs]() {
     try {
       TF_VLOG(3) << "Executing Dynamo IR graph hash "

From 2a8cc73459a38f252bc97464e644b52122437791 Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Wed, 17 May 2023 22:44:25 +0000
Subject: [PATCH 5/5] change comments

---
 torch_xla/csrc/xla_graph_executor.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index b420fc5914b..56056f7e769 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -593,7 +593,8 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
     torch::lazy::BackendDataPtr handle =
         WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
            device.toString(), std::move(shape)));
-    // if SPMD is enabled, we assume all output will be replicated
+    // If SPMD is enabled, we assume all output will be sharded or replicated
+    // and wrapped inside PjRtShardedData handle.
     if (ShardingUtil::UseVirtualDevice()) {
      XLATensor::ShardingSpecPtr sharding =
          std::make_shared<XLATensor::ShardingSpec>(
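
Usage sketch (not part of the patches above): a minimal, standalone illustration of how the dynamo+SPMD path exercised by `test_dynamo_spmd_basic` is expected to be driven from user code. The `xs.Mesh` construction and `xm.get_xla_supported_devices()` call below are assumptions standing in for the test's `self._get_mesh((1, self.n_devices))` helper from `test_xla_sharding_base`, and their exact signatures may differ across torch_xla versions.

    # Hedged sketch mirroring test_dynamo_spmd_basic; xs.Mesh / device-count
    # helpers are assumptions, not part of the patch series.
    import os
    os.environ["XLA_USE_SPMD"] = "1"  # enable the SPMD virtual device, as in setUpClass

    import numpy as np
    import torch
    import torch.nn as nn
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs

    device = xm.xla_device()
    linear = nn.Linear(128, 1).to(device).eval()

    # Shard the weight over a (1, n_devices) mesh, same layout the test uses for fc2.
    n_devices = len(xm.get_xla_supported_devices())
    mesh = xs.Mesh(np.arange(n_devices), (1, n_devices))
    xs.mark_sharding(linear.weight, mesh, (1, 0))

    x = torch.randn(1, 128, device=device)
    eager_res = linear(x)
    xm.mark_step()

    # torch.compile with the torchxla_trace_once backend goes through
    # ExecuteComputationWithBarrier, which now routes sharded cached
    # computations to ExecuteReplicated and wraps outputs as sharded data.
    dynamo_linear = torch.compile(linear, backend="torchxla_trace_once")
    dynamo_res = dynamo_linear(x)
    assert torch.allclose(eager_res.cpu(), dynamo_res.cpu())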