diff --git a/test/run_tests.sh b/test/run_tests.sh
index 1a24bba4b92..36e2dd0c3ab 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -188,6 +188,7 @@ function run_xla_op_tests {
   run_test "$CDIR/pjrt/test_mesh_service.py"
   run_test "$CDIR/spmd/test_xla_sharding.py"
   run_test "$CDIR/spmd/test_xla_virtual_device.py"
+  run_test "$CDIR/spmd/test_dynamo_spmd.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
new file mode 100644
index 00000000000..2be24e1f78a
--- /dev/null
+++ b/test/spmd/test_dynamo_spmd.py
@@ -0,0 +1,58 @@
+import os
+import sys
+
+import torch
+import torch.nn as nn
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.experimental.xla_sharding as xs
+import torch_xla.debug.metrics as met
+import unittest
+
+import test_xla_sharding_base
+
+
+class SimpleLinear(nn.Module):
+
+  def __init__(self):
+    super(SimpleLinear, self).__init__()
+    self.fc1 = nn.Linear(128, 128)
+    self.relu = nn.ReLU()
+    self.fc2 = nn.Linear(128, 1)
+    # Add an additional 1x1 layer at the end to ensure the final layer
+    # is not sharded.
+    self.fc3 = nn.Linear(1, 1)
+
+  def forward(self, x):
+    y = self.relu(self.fc1(x))
+    z = self.fc2(y)
+    return self.fc3(z)
+
+
+class DynamoSpmdInferenceTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    os.environ["XLA_USE_SPMD"] = "1"
+    super().setUpClass()
+
+  def test_dynamo_spmd_basic(self):
+    device = xm.xla_device()
+    linear = SimpleLinear().to(device)
+    linear.eval()
+    xla_x = torch.randn(1, 128, device=device)
+    xs.mark_sharding(linear.fc2.weight, self._get_mesh((1, self.n_devices)),
+                     (1, 0))
+    xla_res = linear(xla_x)
+    xm.mark_step()
+
+    dynamo_linear = torch.compile(linear, backend="torchxla_trace_once")
+    dynamo_res = dynamo_linear(xla_x)
+    self.assertTrue(torch.allclose(xla_res.cpu(), dynamo_res.cpu()))
+    # TODO(JackCaoG): add counter checks after ExecuteReplicated also creates
+    # an ExecuteMetric.
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp
index 892e5bbf8e5..3c7a5dc6fdc 100644
--- a/torch_xla/csrc/aten_xla_bridge.cpp
+++ b/torch_xla/csrc/aten_xla_bridge.cpp
@@ -305,6 +305,11 @@ torch::lazy::BackendDevice AtenDeviceToXlaDevice(const c10::Device& device) {
 }
 
 c10::Device XlaDeviceToAtenDevice(const torch::lazy::BackendDevice& device) {
+  // TODO(yeounoh) until we expose SPMD virtual device to the frontend, this
+  // will just be `XLA:0`.
+  if (device.type() == (int8_t)XlaDeviceType::SPMD) {
+    return c10::Device(at::kXLA, (size_t)0);
+  }
   return c10::Device(at::kXLA,
                      AtenXlaDeviceMapper::Get()->GetDeviceOrdinal(device));
 }
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 47709a83b60..e1bebb1cb19 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1728,7 +1728,10 @@
             -> std::vector<at::Tensor> {
           XLA_CHECK(hash_str.size() == sizeof(torch::lazy::hash_t));
           torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str());
-          torch::lazy::BackendDevice device = torch_xla::GetCurrentDevice();
+          // Device will be Virtual device if SPMD is enabled.
+          torch::lazy::BackendDevice device =
+              ShardingUtil::UseVirtualDevice() ? ParseDeviceString("SPMD:0")
+                                               : torch_xla::GetCurrentDevice();
           auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier(
               hash, graph_inputs, device);
           std::vector<at::Tensor> retlist;
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 9d7cab34aed..56056f7e769 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -574,6 +574,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
   MaybeDumpGraph("dynamo", hash);
   auto cachedComputation =
       XLAGraphExecutor::Get()->GetComputationCache()->Get(hash);
+  TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash)
+             << ") is_sharded=" << cachedComputation->is_sharded << std::endl;
+
   // TODO implement a fallback mechanism, or make sure those entries
   // never get kicked out
   XLA_CHECK(cachedComputation)
@@ -590,6 +593,16 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
     torch::lazy::BackendDataPtr handle =
         WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
             device.toString(), std::move(shape)));
+    // If SPMD is enabled, we assume all output will be sharded or replicated
+    // and wrapped inside PjRtShardedData handle.
+    if (ShardingUtil::UseVirtualDevice()) {
+      XLATensor::ShardingSpecPtr sharding =
+          std::make_shared<XLATensor::ShardingSpec>(
+              xla::HloSharding::Replicate().ToProto(), shape);
+      handle = WrapXlaData(xla::ComputationClient::Get()->WrapDataShards(
+          {UnwrapXlaData(handle)}, GetVirtualDevice().toString(),
+          sharding->shape.value(), sharding->sharding));
+    }
     placeholders.push_back(handle);
   }

@@ -608,6 +621,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
     if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) {
       dataptr = xla_tensor_ptr->GetXlaData();
     } else {
+      XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD)
+          << "SPMD device data should already be on the XLA backend "
+             "(XLATensor).";
       dataptr = torch_xla::TensorToXlaData(ivalue.toTensor(), device);
     }

@@ -619,21 +635,65 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(

   std::shared_ptr<Async> async = std::make_shared<Async>(
       &coll, std::move(arguments), placeholders, std::move(cachedComputation));

-  auto syncfn = [async, hash]() {
-    TF_VLOG(3) << "Executing Dynamo IR graph hash "
-               << torch::lazy::HashToString(hash) << " on device "
-               << async->device << " ...";
-    std::vector<torch::lazy::BackendDataPtr> results =
-        torch::lazy::getBackend()->ExecuteComputation(
+  // TODO(yeounoh) supply proper sharding specs for sharded results.
+  std::vector<XLATensor::ShardingSpecPtr> sharding_specs(placeholders.size());
+
+  auto syncfn = [async, hash, sharding_specs]() {
+    try {
+      TF_VLOG(3) << "Executing Dynamo IR graph hash "
+                 << torch::lazy::HashToString(hash) << " on device "
+                 << async->device << " ...";
+
+      std::vector<torch::lazy::BackendDataPtr> results;
+      if (async->cached_computation->is_sharded) {
+        std::vector<std::string> devices =
+            xla::ComputationClient::Get()->GetLocalDevices();
+        std::vector<std::vector<xla::ComputationClient::DataPtr>>
+            device_arguments = ShardingUtil::InputHandler(
+                UnwrapXlaData(async->parameters_data), devices);
+        xla::ComputationClient::ExecuteReplicatedOptions execute_options;
+        // OutputHandler creates sharded data for sharded
+        // tensor results. Both sharded and unsharded results should be
+        // "Assign"ed to the corresponding data placeholders.
+        std::vector<xla::ComputationClient::DataPtr> outputs =
+            ShardingUtil::OutputHandler(
+                xla::ComputationClient::Get()->ExecuteReplicated(
+                    *async->cached_computation->computation
+                         ->client_computation(),
+                    device_arguments, devices, execute_options),
+                sharding_specs);
+        results = WrapXlaData(outputs);
+        TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
+                   << torch::lazy::HashToString(hash) << " on devices "
+                   << absl::StrJoin(devices, ",") << " done!";
+      } else {
+        results = torch::lazy::getBackend()->ExecuteComputation(
             async->cached_computation->computation, async->parameters_data,
             async->device);
-    TF_VLOG(3) << "Executing Dynamo IR graph hash "
-               << torch::lazy::HashToString(hash) << " on device "
-               << async->device << " done!";
-    // Updating placeholder with actual output handle.
-    for (size_t i = 0; i < results.size(); ++i) {
-      XLA_CHECK(async->tensors_data[i] != nullptr);
-      async->tensors_data[i]->Assign(*results[i]);
+        TF_VLOG(3) << "Executing Dynamo IR graph hash "
+                   << torch::lazy::HashToString(hash) << " on device "
+                   << async->device << " done!";
+      }
+
+      // Updating placeholder with actual output handle.
+      for (size_t i = 0; i < results.size(); ++i) {
+        XLA_CHECK(async->tensors_data[i] != nullptr);
+        async->tensors_data[i]->Assign(*results[i]);
+      }
+    } catch (...) {
+      // There are two paths of discovery of an exception happening on an
+      // asynchronous task. One happens if the creator of the asynchronous task
+      // explicitly waits for completion, in which case the exception will be
+      // thrown from the Wait() API. Re-throwing the exception below makes sure
+      // this will be captured by the completer function created below, and
+      // surfaced by the Wait() API. But we also need to surface the exception
+      // even in case the caller does not wait, and that is accomplished by
+      // setting the unlockers status. In that case the exception will be
+      // surfaced when the user tries to acquire the device locks the next time.
+      for (auto& unlocker : async->unlocker) {
+        unlocker.SetStatus(std::current_exception());
+      }
+      throw;
     }
   };