diff --git a/third_party/xla_client/multi_wait.cc b/third_party/xla_client/multi_wait.cc
index 3adbcd0814f0..d3cb7b834771 100644
--- a/third_party/xla_client/multi_wait.cc
+++ b/third_party/xla_client/multi_wait.cc
@@ -41,7 +41,7 @@ std::function<void()> MultiWait::Completer(std::function<void()> func) {
       func();
       Done();
     } catch (const std::exception& ex) {
-      Done(tensorflow::errors::Aborted(ex.what()));
+      Done(tensorflow::errors::Internal(ex.what()));
     }
   };
   return completer;
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index f31e5c746687..515c8605415e 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -12,7 +12,6 @@
 #include "tensorflow/compiler/xla/xla_client/cache.h"
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
 #include "tensorflow/compiler/xla/xla_client/metrics.h"
-#include "tensorflow/compiler/xla/xla_client/multi_wait.h"
 #include "tensorflow/compiler/xla/xla_client/sys_util.h"
 #include "tensorflow/compiler/xla/xla_client/thread_pool.h"
 #include "tensorflow/compiler/xla/xla_client/unique.h"
@@ -356,8 +355,7 @@ xla::ComputationClient::DataPtr XLATensor::CurrentXlaData() const {
     }
     const ir::ops::DeviceData* device_data =
         dynamic_cast<const ir::ops::DeviceData*>(ir_value.node.get());
-    if (device_data != nullptr &&
-        device_data->data().get() == data()->xla_data.get()) {
+    if (device_data != nullptr && device_data->data() == data()->xla_data) {
       return data()->xla_data;
     }
   }
@@ -370,7 +368,7 @@ std::string XLATensor::DumpGraphNodeComputation() const {
   if (ir_value) {
     ir::LoweringContext lowering_ctx("DumpGraphNodeComputation");
     xla::XlaOp root = lowering_ctx.GetOutputOp(ir_value);
-    auto computation = ConsumeValue(lowering_ctx.Build(root));
+    xla::XlaComputation computation = ConsumeValue(lowering_ctx.Build(root));
     hlo_text = ConsumeValue(xla::xrt_util::GetComputationHloText(computation));
   }
   return hlo_text;
@@ -408,7 +406,7 @@ void XLATensor::TryLimitGraphSize() {
   static const size_t kCheckFrequency =
       xla::sys_util::GetEnvInt("TRIM_GRAPH_CHECK_FREQUENCY", 100);
   static const size_t kMaxPendingGraphSize =
-      xla::sys_util::GetEnvInt("TRIM_GRAPH_SIZE", 5000);
+      xla::sys_util::GetEnvInt("TRIM_GRAPH_SIZE", 10000);
   static std::atomic<size_t> counter(1);
   if (data()->ir_value && counter.fetch_add(1) % kCheckFrequency == 0) {
     size_t graph_size = ir::Util::GetGraphSize({data()->ir_value.node.get()});
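// Illustrative sketch (not from the patch), assuming DataPtr behaves like a
// std::shared_ptr: operator== on two shared_ptr objects already compares the
// managed raw pointers, which is why the explicit .get() calls could be
// dropped in the CurrentXlaData() hunk above.
#include <cassert>
#include <memory>

int main() {
  auto a = std::make_shared<int>(42);
  std::shared_ptr<int> b = a;          // same managed object as 'a'
  auto c = std::make_shared<int>(42);  // equal value, distinct object
  assert((a == b) == (a.get() == b.get()));  // both true
  assert((a == c) == (a.get() == c.get()));  // both false
  return 0;
}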
@@ -555,26 +553,62 @@ std::vector<XLATensor> XLATensor::GetLiveTensors() {
   return TensorsArena::Get()->GetTensors();
 }
 
-std::vector<at::Tensor> XLATensor::GetTensors(
-    std::vector<XLATensor>* tensors, const std::vector<bool>* writeable) {
-  // TODO(dlibenzi): We do apply/compute and then fetch. Changing the API to
-  // support getting handles and data might save a few pennies here.
-  ApplyPendingGraph(tensors);
-
+std::vector<xla::ComputationClient::DataPtr> XLATensor::GatherTensorsXlaData(
+    const std::vector<XLATensor>& tensors, std::shared_ptr<Async> async) {
   std::vector<xla::ComputationClient::DataPtr> tensors_data;
-  for (auto& tensor : *tensors) {
-    if (!tensor.CurrentTensorData()) {
-      tensors_data.push_back(tensor.GetXlaData());
+  if (async != nullptr) {
+    size_t indices_index = 0;
+    for (size_t i = 0; i < tensors.size(); ++i) {
+      if (!tensors[i].CurrentTensorData()) {
+        if (indices_index < async->indices.size() &&
+            i == async->indices[indices_index]) {
+          // If we are at the current index (it means that the tensor at index
+          // 'i' had an IR node to sync), use the XLA data held within the
+          // Async object.
+          tensors_data.push_back(async->tensors_data[indices_index]);
+        } else {
+          xla::ComputationClient::DataPtr xla_data =
+              tensors[i].CurrentXlaData();
+          XLA_CHECK(xla_data != nullptr);
+          tensors_data.push_back(std::move(xla_data));
+        }
+      }
+      if (indices_index < async->indices.size() &&
+          i == async->indices[indices_index]) {
+        ++indices_index;
+      }
+    }
+  } else {
+    // If we are here, async is nullptr, which means that none of the input
+    // tensors had an IR node to sync. This means that they either have
+    // at::Tensor data, or XLA data.
+    for (auto& tensor : tensors) {
+      if (!tensor.CurrentTensorData()) {
+        xla::ComputationClient::DataPtr xla_data = tensor.CurrentXlaData();
+        XLA_CHECK(xla_data != nullptr);
+        tensors_data.push_back(std::move(xla_data));
+      }
     }
   }
+  return tensors_data;
+}
+
+std::vector<at::Tensor> XLATensor::GetTensors(
+    std::vector<XLATensor>* tensors, const std::vector<bool>* writeable) {
+  SyncTensorsConfig config;
+  auto async = SyncTensorsGraphInternal(tensors, config);
+  if (async != nullptr) {
+    XLA_CHECK_OK(async->mwait.Wait());
+  }
+  std::vector<xla::ComputationClient::DataPtr> tensors_data =
+      GatherTensorsXlaData(*tensors, async);
   std::vector<xla::Literal> literals =
       xla::ComputationClient::Get()->TransferFromServer(tensors_data);
   std::vector<at::Tensor> results;
   size_t literals_index = 0;
   results.reserve(tensors->size());
   for (size_t i = 0; i < tensors->size(); ++i) {
-    const c10::optional<at::Tensor>& tensor_data =
-        (*tensors)[i].CurrentTensorData();
+    c10::optional<at::Tensor> tensor_data = (*tensors)[i].CurrentTensorData();
     if (tensor_data) {
       results.push_back(*tensor_data);
     } else {
@@ -672,7 +706,7 @@ void XLATensor::ApplyPendingGraph() {
     xla::XlaOp root = lowering_ctx.GetOutputOp(ir_value);
     xla::XlaComputation computation = ConsumeValue(lowering_ctx.Build(root));
     xla::Shape output_shape = shape().get();
-    const xla::Shape computation_shape =
+    xla::Shape computation_shape =
         ConsumeValue(computation.GetProgramShape()).result();
     // Some in-place operations (e.g. squeeze) can change the shape.
     if (!xla::ShapeUtil::Compatible(computation_shape, output_shape)) {
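// Illustrative sketch (not from the patch): the GatherTensorsXlaData() logic
// above walks a sorted list of synced indices alongside the full tensor list,
// taking the freshly produced value for synced positions and the already
// present value otherwise. The stand-alone version below uses hypothetical
// int data instead of XLA handles.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<int> MergeByIndices(const std::vector<int>& existing,
                                const std::vector<std::size_t>& synced_indices,
                                const std::vector<int>& synced_values) {
  std::vector<int> result;
  std::size_t indices_index = 0;
  for (std::size_t i = 0; i < existing.size(); ++i) {
    if (indices_index < synced_indices.size() &&
        i == synced_indices[indices_index]) {
      // Position 'i' was part of the sync; take the new value.
      result.push_back(synced_values[indices_index]);
      ++indices_index;
    } else {
      // Position 'i' was not synced; keep what was already there.
      result.push_back(existing[i]);
    }
  }
  return result;
}

int main() {
  std::vector<int> existing = {10, 20, 30, 40};
  std::vector<std::size_t> synced_indices = {1, 3};
  std::vector<int> synced_values = {-2, -4};
  for (int v : MergeByIndices(existing, synced_indices, synced_values)) {
    std::cout << v << " ";  // prints: 10 -2 30 -4
  }
  std::cout << "\n";
  return 0;
}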
@@ -744,12 +778,13 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors(
   return coll;
 }
 
-bool XLATensor::TryRunCachedSync(std::vector<XLATensor>* tensors,
-                                 SyncTensorCollection* coll) {
+std::shared_ptr<XLATensor::Async> XLATensor::TryRunCachedSync(
+    std::vector<XLATensor>* tensors, const SyncTensorsConfig& config,
+    SyncTensorCollection* coll) {
   ComputationCache::TypePtr cached_computation =
       GetComputationCache()->Get(coll->hash);
   if (cached_computation == nullptr) {
-    return false;
+    return nullptr;
   }
   xla::xla_util::Unique<Device> unique_device;
@@ -769,14 +804,13 @@ bool XLATensor::TryRunCachedSync(std::vector<XLATensor>* tensors,
     }
   }
   if (cached_computation->num_parameters != parameters_data.size()) {
-    return false;
+    return nullptr;
   }
   XLA_COUNTER("CachedSyncTensors", 1);
-  ScheduleSyncTensorsGraph(tensors, coll, std::move(parameters_data),
-                           unique_device->ToString(),
-                           std::move(cached_computation));
-  return true;
+  return ScheduleSyncTensorsGraph(
+      tensors, config, coll, std::move(parameters_data),
+      unique_device->ToString(), std::move(cached_computation));
 }
 
 XLATensor::ComputationCache* XLATensor::GetComputationCache() {
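// Illustrative sketch (not from the patch): TryRunCachedSync() now returns a
// handle to the in-flight work instead of a bool, so the caller can both
// detect a cache miss (nullptr) and later wait on the returned object. The
// hypothetical types below stand in for Async and the cached computation.
#include <iostream>
#include <memory>
#include <string>

struct FakeAsync {
  std::string source;  // which path produced this work
};

std::shared_ptr<FakeAsync> TryCachedPath(bool cache_hit) {
  if (!cache_hit) {
    return nullptr;  // cache miss: caller must take the slow path
  }
  return std::make_shared<FakeAsync>(FakeAsync{"cached"});
}

std::shared_ptr<FakeAsync> RunInternal(bool cache_hit) {
  auto async = TryCachedPath(cache_hit);
  if (async != nullptr) {
    return async;
  }
  // Slow path: compile, then schedule, mirroring SyncTensorsGraphInternal().
  return std::make_shared<FakeAsync>(FakeAsync{"compiled"});
}

int main() {
  std::cout << RunInternal(true)->source << "\n";   // cached
  std::cout << RunInternal(false)->source << "\n";  // compiled
  return 0;
}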
@@ -791,45 +825,26 @@ XLATensor::ApplyContextCache* XLATensor::GetApplyContextCache() {
   return cache;
 }
 
-void XLATensor::ScheduleSyncTensorsGraph(
-    std::vector<XLATensor>* tensors, SyncTensorCollection* coll,
+std::shared_ptr<XLATensor::Async> XLATensor::ScheduleSyncTensorsGraph(
+    std::vector<XLATensor>* tensors, const SyncTensorsConfig& config,
+    SyncTensorCollection* coll,
     std::vector<xla::ComputationClient::DataPtr> parameters_data,
     std::string device, ComputationCache::TypePtr cached_computation) {
-  // The xla::util::Cleanup is not copiable, so even though we can create a
-  // lambda my moving it inside the capture, channeling thorugh an
-  // std::function<> requires captures to be copiable (even though nothing would
-  // actually be copied).
-  struct Async {
-    Async(SyncTensorCollection* coll,
-          std::vector<xla::ComputationClient::DataPtr> parameters_data,
-          std::string device, ComputationCache::TypePtr cached_computation)
-        : unlocker(std::move(coll->unlocker)),
-          parameters_data(std::move(parameters_data)),
-          device(std::move(device)),
-          cached_computation(std::move(cached_computation)) {
-      tensors_data.reserve(coll->indices.size());
-    }
-
-    std::vector<xla::util::Cleanup> unlocker;
-    std::vector<xla::ComputationClient::DataPtr> parameters_data;
-    std::string device;
-    ComputationCache::TypePtr cached_computation;
-    std::vector<xla::ComputationClient::DataPtr> tensors_data;
-  };
   std::shared_ptr<Async> async = std::make_shared<Async>(
       coll, std::move(parameters_data), device, std::move(cached_computation));
-  for (auto index : coll->indices) {
-    // The purpose of a tensor sync operation is to truncate the IR graph and
-    // materialize device data in place of IR graph, on selected tensors.
-    // But since operation will complete asynchronously, if a tensor does not
-    // already have device data, we need to install a placeholder. Since at this
-    // point we hold a lock on the device where the tensors reside (locks held
-    // within the coll structure, and moved into the async variable), any other
-    // operation trying to access the tensor's device data will have to wait
-    // until the asynchronous operation completes.
+  for (auto index : async->indices) {
+    // If the config.force_xla_data flag is true, the purpose of this tensor
+    // sync operation is to truncate the IR graph and materialize device data
+    // in place of the IR graph, on selected tensors. But since the operation
+    // will complete asynchronously, if a tensor does not already have device
+    // data, we need to install a placeholder. Since at this point we hold a
+    // lock on the device where the tensors reside (locks held within the coll
+    // structure, and moved into the async variable), any other operation
+    // trying to access the tensor's device data will have to wait until the
+    // asynchronous operation completes.
     xla::ComputationClient::DataPtr xla_data =
         (*tensors)[index].CurrentXlaData();
-    if (xla_data == nullptr) {
+    if (xla_data == nullptr && config.force_xla_data) {
       xla_data = xla::ComputationClient::Get()->CreateDataPlaceholder(
           device, (*tensors)[index].shape());
       (*tensors)[index].SetXlaData(xla_data);
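// Illustrative sketch (not from the patch): the placeholder logic above
// installs an empty device-data handle up front, so the tensor already owns a
// slot while the computation runs asynchronously, and the slot is filled
// (Swap'ed) when results arrive. The stand-alone version below models the
// placeholder with a promise/future pair instead of the real
// CreateDataPlaceholder()/Swap() APIs.
#include <future>
#include <iostream>
#include <thread>
#include <utility>
#include <vector>

int main() {
  // "Placeholder": the tensor-side handle exists immediately...
  std::promise<std::vector<float>> placeholder;
  std::shared_future<std::vector<float>> handle = placeholder.get_future();

  // ...while the "execution" fills it in on another thread later.
  std::thread io([p = std::move(placeholder)]() mutable {
    p.set_value(std::vector<float>{1.0f, 2.0f, 3.0f});  // async results
  });

  // Any consumer touching the handle waits for the async work to finish.
  std::cout << "first element: " << handle.get()[0] << "\n";
  io.join();
  return 0;
}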
@@ -837,65 +852,80 @@ void XLATensor::ScheduleSyncTensorsGraph(
     async->tensors_data.emplace_back(std::move(xla_data));
   }
 
-  auto syncfn = [async = std::move(async)]() {
+  auto syncfn = [async]() {
     xla::ComputationClient::ExecuteComputationOptions options;
     try {
       auto results = xla::ComputationClient::Get()->ExecuteComputation(
           *async->cached_computation->computation, async->parameters_data,
          async->device, options);
       for (size_t i = 0; i < results.size(); ++i) {
-        async->tensors_data[i]->Swap(results[i].get());
+        if (async->tensors_data[i] != nullptr) {
+          async->tensors_data[i]->Swap(results[i].get());
+        } else {
+          async->tensors_data[i] = std::move(results[i]);
+        }
       }
     } catch (const std::exception& ex) {
-      xla::Status status = tensorflow::errors::Aborted(ex.what());
+      xla::Status status = tensorflow::errors::Internal(ex.what());
      for (auto& unlocker : async->unlocker) {
        unlocker.SetStatus(status);
      }
    }
   };
-  xla::xla_env::ScheduleIoClosure(std::move(syncfn));
+  xla::xla_env::ScheduleIoClosure(async->mwait.Completer(std::move(syncfn)));
+  return async;
 }
 
 void XLATensor::SyncTensorsGraph(std::vector<XLATensor>* tensors) {
-  SyncTensorCollection coll = CollectSyncTensors(*tensors);
-  if (!coll.indices.empty() && !TryRunCachedSync(tensors, &coll)) {
-    XLA_COUNTER("UncachedSyncTensors", 1);
+  SyncTensorsConfig config;
+  config.force_xla_data = true;
+  SyncTensorsGraphInternal(tensors, config);
+}
-    xla::xla_util::Unique<Device> unique_device;
-    ir::LoweringContext lowering_ctx("SyncTensorsGraph");
-    for (auto index : coll.indices) {
-      ir::Value ir_value = (*tensors)[index].CurrentIrValue();
-      xla::XlaOp root = lowering_ctx.GetOutputOp(ir_value);
-      lowering_ctx.AddResult(root);
-      unique_device.set((*tensors)[index].GetDevice());
-    }
+std::shared_ptr<XLATensor::Async> XLATensor::SyncTensorsGraphInternal(
+    std::vector<XLATensor>* tensors, const SyncTensorsConfig& config) {
+  SyncTensorCollection coll = CollectSyncTensors(*tensors);
+  if (coll.indices.empty()) {
+    return nullptr;
+  }
+  std::shared_ptr<Async> async = TryRunCachedSync(tensors, config, &coll);
+  if (async != nullptr) {
+    return async;
+  }
+  XLA_COUNTER("UncachedSyncTensors", 1);
-    xla::XlaComputation computation = ConsumeValue(lowering_ctx.Build());
-    xla::ProgramShape program_shape =
-        ConsumeValue(computation.GetProgramShape());
-    xla::Shape shape = MakeShapeWithDeviceLayout(program_shape.result(),
-                                                 unique_device->hw_type);
-
-    std::vector<xla::ComputationClient::CompileInstance> instances;
-    instances.push_back({std::move(computation),
-                         GetCompilationDevices(unique_device->ToString()),
-                         &shape});
-
-    std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
-        computations =
-            xla::ComputationClient::Get()->Compile(std::move(instances));
-    std::vector<xla::ComputationClient::DataPtr> parameters_data =
-        lowering_ctx.GetParametersData();
-    ComputationCache::TypePtr cached_computation = GetComputationCache()->Add(
-        coll.hash,
-        std::make_shared<CachedComputation>(std::move(computations.front()),
-                                            parameters_data.size()));
-
-    ScheduleSyncTensorsGraph(tensors, &coll, std::move(parameters_data),
-                             unique_device->ToString(),
-                             std::move(cached_computation));
+  xla::xla_util::Unique<Device> unique_device;
+  ir::LoweringContext lowering_ctx("SyncTensorsGraph");
+  for (auto index : coll.indices) {
+    ir::Value ir_value = (*tensors)[index].CurrentIrValue();
+    xla::XlaOp root = lowering_ctx.GetOutputOp(ir_value);
+    lowering_ctx.AddResult(root);
+    unique_device.set((*tensors)[index].GetDevice());
   }
+
+  xla::XlaComputation computation = ConsumeValue(lowering_ctx.Build());
+  xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
+  xla::Shape shape =
+      MakeShapeWithDeviceLayout(program_shape.result(), unique_device->hw_type);
+
+  std::vector<xla::ComputationClient::CompileInstance> instances;
+  instances.push_back({std::move(computation),
+                       GetCompilationDevices(unique_device->ToString()),
+                       &shape});
+
+  std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
+      computations =
+          xla::ComputationClient::Get()->Compile(std::move(instances));
+  std::vector<xla::ComputationClient::DataPtr> parameters_data =
+      lowering_ctx.GetParametersData();
+  ComputationCache::TypePtr cached_computation = GetComputationCache()->Add(
+      coll.hash, std::make_shared<CachedComputation>(
+                     std::move(computations.front()), parameters_data.size()));
+
+  return ScheduleSyncTensorsGraph(
+      tensors, config, &coll, std::move(parameters_data),
+      unique_device->ToString(), std::move(cached_computation));
 }
 
 XLATensor::SyncTensorCollection XLATensor::CollectApplyGraphTensors(
@@ -936,7 +966,7 @@ bool XLATensor::RunCachedApply(std::vector<XLATensor>* tensors,
   for (auto uid : device_input_mapping) {
     auto it = uid_index_map.find(uid);
     if (it != uid_index_map.end()) {
-      const xla::ComputationClient::DataPtr& xla_data =
+      xla::ComputationClient::DataPtr xla_data =
           (*tensors)[it->second].data()->xla_data;
       if (xla_data == nullptr) {
         // If we do not find real device data (we have a cached graph
@@ -945,7 +975,7 @@ bool XLATensor::RunCachedApply(std::vector<XLATensor>* tensors,
         XLA_COUNTER("NoTensorDataForUid", 1);
         return false;
       }
-      device_parameters.push_back(xla_data);
+      device_parameters.push_back(std::move(xla_data));
     } else {
       // If we have not found the unique ID of the parameter which is
       // supposed to feed data to the computation, the pending graph context
@@ -1131,12 +1161,11 @@ void XLATensor::ApplyPendingGraph(std::vector<XLATensor>* tensors) {
                               context_iterator->second.index_mapping);
       ++context_iterator;
     }
-    if (unknown_params == 0) {
-      apply_context = std::make_shared<ApplyContext>();
-      *apply_context = {std::move(computations), std::move(uid_order),
-                        std::move(input_mapping), std::move(index_mapping),
-                        std::move(devices)};
+    apply_context = std::make_shared<ApplyContext>(
+        std::move(computations), std::move(uid_order),
+        std::move(input_mapping), std::move(index_mapping),
+        std::move(devices));
     GetApplyContextCache()->Add(hash, std::move(apply_context));
   }
 }
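// Illustrative sketch (not from the patch): syncfn above is wrapped by
// mwait.Completer(), so completion is always signaled and a thrown exception
// becomes a status that waiters can observe (the same idea the
// unlocker.SetStatus() loop applies to the device locks). The simplified
// completer below reproduces that shape with standard C++ only; Status here is
// just a std::string, not xla::Status.
#include <future>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>

using Status = std::string;  // empty string means OK

// Runs the closure, then records either an OK status or the exception text,
// so a waiter blocked on the future can never hang.
template <typename Fn>
void RunAndComplete(Fn fn, std::promise<Status>& done) {
  try {
    fn();
    done.set_value(Status());
  } catch (const std::exception& ex) {
    done.set_value(Status(ex.what()));
  }
}

int main() {
  std::promise<Status> done;
  std::future<Status> status = done.get_future();
  std::thread worker([&done]() {  // stand-in for ScheduleIoClosure()
    RunAndComplete([]() { throw std::runtime_error("boom"); }, done);
  });
  std::cout << "status: " << status.get() << "\n";  // prints: status: boom
  worker.join();
  return 0;
}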
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 065579e9367d..53c00c3fc750 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -8,6 +8,7 @@
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_client/cache.h"
 #include "tensorflow/compiler/xla/xla_client/computation_client.h"
+#include "tensorflow/compiler/xla/xla_client/multi_wait.h"
 #include "tensorflow/compiler/xla/xla_client/util.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch_xla/csrc/device.h"
@@ -121,7 +122,7 @@ class XLATensor {
   // Retrieves the PyTorch tensors behind the XLA tensors. If the writeable
   // vector is not nullptr, it must be the same size as tensors, and the
   // corresponding bool tells whether the ATEN tensor to be retrieved should the
-  // a writeable copy.
+  // a writeable copy. All the tensors must be on the same device.
   static std::vector<at::Tensor> GetTensors(std::vector<XLATensor>* tensors,
                                             const std::vector<bool>* writeable);
@@ -819,10 +820,52 @@ class XLATensor {
   using ComputationCache = xla::util::Cache;
 
+  struct Async {
+    Async(SyncTensorCollection* coll,
+          std::vector<xla::ComputationClient::DataPtr> parameters_data,
+          std::string device, ComputationCache::TypePtr cached_computation)
+        : mwait(1),
+          indices(std::move(coll->indices)),
+          unlocker(std::move(coll->unlocker)),
+          parameters_data(std::move(parameters_data)),
+          device(std::move(device)),
+          cached_computation(std::move(cached_computation)) {
+      tensors_data.reserve(indices.size());
+    }
+
+    xla::xla_util::MultiWait mwait;
+    std::vector<size_t> indices;
+    std::vector<xla::util::Cleanup> unlocker;
+    std::vector<xla::ComputationClient::DataPtr> parameters_data;
+    std::string device;
+    ComputationCache::TypePtr cached_computation;
+    std::vector<xla::ComputationClient::DataPtr> tensors_data;
+  };
+
+  struct SyncTensorsConfig {
+    // Whether we want to force XLA data on the target tensors (hence trimming
+    // the IR graph above them).
+    bool force_xla_data = false;
+  };
+
   // The context used by the ApplyPendingGraph() API, in order to allow it speed
   // up operations in case the new tensors graph apply matches the one stored
   // within the apply context.
   struct ApplyContext {
+    ApplyContext() = default;
+    ApplyContext(
+        std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
+            computations,
+        std::vector<xla::int64> uid_order,
+        std::vector<std::vector<xla::int64>> input_mapping,
+        std::vector<std::vector<xla::int64>> index_mapping,
+        std::vector<std::string> devices)
+        : computations(std::move(computations)),
+          uid_order(std::move(uid_order)),
+          input_mapping(std::move(input_mapping)),
+          index_mapping(std::move(index_mapping)),
+          devices(std::move(devices)) {}
+
     std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
         computations;
     std::vector<xla::int64> uid_order;
@@ -957,16 +1000,26 @@ class XLATensor {
   static SyncTensorCollection CollectSyncTensors(
       const std::vector<XLATensor>& tensors);
 
+  // Gathers the XLA device data for all the input tensors after an
+  // asynchronous operation.
+  static std::vector<xla::ComputationClient::DataPtr> GatherTensorsXlaData(
+      const std::vector<XLATensor>& tensors, std::shared_ptr<Async> async);
+
   // Schedules the execution of a sync tensors operation in background. The
   // asynchronous operation will hold the device locks by capturing the ones
   // present within the coll structure.
-  static void ScheduleSyncTensorsGraph(
-      std::vector<XLATensor>* tensors, SyncTensorCollection* coll,
+  static std::shared_ptr<Async> ScheduleSyncTensorsGraph(
+      std::vector<XLATensor>* tensors, const SyncTensorsConfig& config,
+      SyncTensorCollection* coll,
       std::vector<xla::ComputationClient::DataPtr> parameters_data,
       std::string device, ComputationCache::TypePtr cached_computation);
 
-  static bool TryRunCachedSync(std::vector<XLATensor>* tensors,
-                               SyncTensorCollection* coll);
+  static std::shared_ptr<Async> TryRunCachedSync(
+      std::vector<XLATensor>* tensors, const SyncTensorsConfig& config,
+      SyncTensorCollection* coll);
+
+  static std::shared_ptr<Async> SyncTensorsGraphInternal(
+      std::vector<XLATensor>* tensors, const SyncTensorsConfig& config);
 
   // Returns a permutation which represents an ordering by tensor device and
   // unique ID, of all the tensors which needs sync (the ones which have a graph