From 8fb92c84359d6b3099272d93b8e19c369ebc173c Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Sat, 16 Mar 2019 11:34:27 -0700 Subject: [PATCH] Overlap XLA tensor sync with future XLA tensor work. Prepare XRT session so that they do not fall into the dreaded one-more-node-in-graph TF case. Reserve XRT sessions just for allocations (sending data to device). --- setup.py | 1 + test/cpp/CMakeLists.txt | 1 + test/cpp/test_replication.cpp | 6 +- test/cpp/test_xla_util_cache.cpp | 33 ++ test/test_operations.py | 4 + third_party/xla_client/cache.h | 69 ++-- third_party/xla_client/computation_client.cc | 8 +- third_party/xla_client/computation_client.h | 30 +- third_party/xla_client/metrics.h | 5 + third_party/xla_client/multi_wait.cc | 2 +- third_party/xla_client/util.h | 31 ++ .../xla_client/xrt_computation_client.cc | 166 +++++++--- .../xla_client/xrt_computation_client.h | 82 +++-- third_party/xla_client/xrt_session_cache.cc | 12 +- third_party/xla_client/xrt_session_cache.h | 3 + torch_xla/csrc/lowering_context.cpp | 9 - torch_xla/csrc/lowering_context.h | 7 +- torch_xla/csrc/module.cpp | 8 +- torch_xla/csrc/module.h | 2 +- torch_xla/csrc/tensor.cpp | 300 ++++++++++++++---- torch_xla/csrc/tensor.h | 52 ++- 21 files changed, 602 insertions(+), 229 deletions(-) create mode 100644 test/cpp/test_xla_util_cache.cpp diff --git a/setup.py b/setup.py index bcfbc6dfac8d..04171c02ccdb 100644 --- a/setup.py +++ b/setup.py @@ -162,6 +162,7 @@ def make_relative_rpath(path): extra_compile_args = [ + '-std=c++14', '-Wno-sign-compare', '-Wno-deprecated-declarations', '-Wno-return-type', diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index 01f3d01eeea6..65f7b4e01b50 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -56,6 +56,7 @@ set(TORCH_XLA_TEST_SOURCES test_mayberef.cpp test_replication.cpp test_tensor.cpp + test_xla_util_cache.cpp torch_xla_test.cpp ) diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 954d2abfdb5d..29972af07a50 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -49,14 +49,14 @@ void TestSingleReplication(const std::vector& devices) { } auto tensors_data = CreateTensorsData(tensors, device_strings); - std::vector>> - results(device_strings.size()); + std::vector> results( + device_strings.size()); xla::xla_util::MultiWait mwait(device_strings.size()); xla::ComputationClient::ExecuteComputationOptions exec_options; for (size_t i = 0; i < device_strings.size(); ++i) { auto executor = [&, i]() { results[i] = xla::ComputationClient::Get()->ExecuteComputation( - *compiled_computations[i], {tensors_data[i].get()}, device_strings[i], + *compiled_computations[i], {tensors_data[i]}, device_strings[i], exec_options); }; xla::xla_env::ScheduleIoClosure(mwait.Completer(std::move(executor))); diff --git a/test/cpp/test_xla_util_cache.cpp b/test/cpp/test_xla_util_cache.cpp new file mode 100644 index 000000000000..f29c8c8c0d84 --- /dev/null +++ b/test/cpp/test_xla_util_cache.cpp @@ -0,0 +1,33 @@ +#include + +#include + +#include "cpp_test_util.h" +#include "tensorflow/compiler/xla/xla_client/cache.h" +#include "tensorflow/compiler/xla/xla_client/util.h" + +namespace torch_xla { +namespace cpp_test { + +TEST(XlaUtilCacheTest, BasicTest) { + static const size_t kMaxSize = 64; + xla::util::Cache cache(kMaxSize); + + for (int i = 0; i < 2 * kMaxSize; ++i) { + std::string istr = std::to_string(i); + auto ptr = cache.Add(i, std::make_shared(istr)); + ASSERT_NE(ptr, nullptr); + EXPECT_EQ(*ptr, istr); + + ptr = 
cache.Get(i); + ASSERT_NE(ptr, nullptr); + EXPECT_EQ(*ptr, istr); + } + for (int i = 0; i < kMaxSize - 1; ++i) { + auto ptr = cache.Get(i); + EXPECT_EQ(ptr, nullptr); + } +} + +} // namespace cpp_test +} // namespace torch_xla diff --git a/test/test_operations.py b/test/test_operations.py index 88244770d675..77115816e908 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -634,6 +634,8 @@ def loop_fn(model, loader): model_parallel = dp.DataParallel( XlaMNIST, train_loader, loop_fn, device_ids=devices) model_parallel() + if xu.getenv_as('METRICS_DEBUG', bool, defval=False): + print(torch_xla._XLAC._xla_metrics_report()) class TestParallelTensorResnet18(XlaTestCase): @@ -665,6 +667,8 @@ def loop_fn(model, loader): model_parallel = dp.DataParallel( torchvision.models.resnet18, train_loader, loop_fn, device_ids=devices) model_parallel() + if xu.getenv_as('METRICS_DEBUG', bool, defval=False): + print(torch_xla._XLAC._xla_metrics_report()) class AxPlusB(nn.Module): diff --git a/third_party/xla_client/cache.h b/third_party/xla_client/cache.h index 463f96ca84cd..df54dd60f0bb 100644 --- a/third_party/xla_client/cache.h +++ b/third_party/xla_client/cache.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -10,65 +11,49 @@ namespace xla { namespace util { -// Generic key and object cache with LRU expiration policy. +// Generic key and object cache with LRU expiration policy. The objects of type +// T will be stored as std::shared_ptr and taken and returned as such, by the +// cache API. template , typename E = std::equal_to> class Cache { - using Element = std::pair; - using ElementList = std::list; - - struct Hasher { - size_t operator()(const K* key) const { return hasher(*key); } - - H hasher; - }; - - struct Equaler { - bool operator()(const K* k1, const K* k2) const { - return equaler(*k1, *k2); - } - - E equaler; - }; - - using ElementMap = - std::unordered_map; - public: + using TypePtr = std::shared_ptr; + using Element = std::pair; + explicit Cache(size_t max_size) : max_size_(max_size) {} // Adds an object to the cache, unless it already exists. If the cache grows // beyond the limit set during construction, the oldest used object will be // removed from the cache. - void Add(K key, T object) { + TypePtr Add(K key, TypePtr object) { std::lock_guard slock(lock_); element_list_.emplace_front(Element(std::move(key), std::move(object))); auto it = element_list_.begin(); - if (!element_map_.emplace(&it->first, it).second) { + auto emplace_result = element_map_.emplace(&it->first, it); + if (!emplace_result.second) { element_list_.erase(it); + DoLRU(emplace_result.first->second); } else if (element_list_.size() > max_size_) { Element* last = &element_list_.back(); element_map_.erase(&last->first); element_list_.pop_back(); } + return emplace_result.first->second->second; } // Retrieves the existing object if it exists. If it does, it's position in // the LRU list gets moved to the head of the list. // Returns nullptr if no object with the specified key is found within the // cache. - const T* Get(const K& key) { + TypePtr Get(const K& key) { std::lock_guard slock(lock_); auto it = element_map_.find(&key); if (it == element_map_.end()) { return nullptr; } - if (it->second != element_list_.begin()) { - // LRU re-positioning. 
- element_list_.splice(element_list_.begin(), element_list_, it->second); - } - return &it->second->second; + DoLRU(it->second); + return it->second->second; } bool Erase(const K& key) { @@ -90,6 +75,30 @@ class Cache { } private: + using ElementList = std::list; + + struct Hasher { + size_t operator()(const K* key) const { return hasher(*key); } + + H hasher; + }; + + struct Equaler { + bool operator()(const K* k1, const K* k2) const { + return equaler(*k1, *k2); + } + + E equaler; + }; + + using ElementMap = + std::unordered_map; + + void DoLRU(typename ElementList::iterator it) { + element_list_.splice(element_list_.begin(), element_list_, it); + } + std::mutex lock_; size_t max_size_ = 0; ElementList element_list_; diff --git a/third_party/xla_client/computation_client.cc b/third_party/xla_client/computation_client.cc index dc9d682292b2..1c795d699e90 100644 --- a/third_party/xla_client/computation_client.cc +++ b/third_party/xla_client/computation_client.cc @@ -52,10 +52,11 @@ void AddXrtHostDevices(const string& worker_name, int task_no, XrtComputationClient::Options* options) { struct Devices { const char* name; + const char* tf_name; int count; } const devices[] = { - {"TPU", sys_util::GetEnvInt("TPU_NUM_DEVICES", 8)}, - {"CPU", sys_util::GetEnvInt("CPU_NUM_DEVICES", 1)}, + {"TPU", "TPU", sys_util::GetEnvInt("TPU_NUM_DEVICES", 8)}, + {"CPU", "XLA_CPU", sys_util::GetEnvInt("CPU_NUM_DEVICES", 1)}, }; string host_port = server.compare(0, 7, "grpc://") == 0 ? server @@ -66,9 +67,10 @@ void AddXrtHostDevices(const string& worker_name, int task_no, int& device_ordinal = (*device_ordinals)[device.name]; for (int j = 0; j < device.count; ++j, ++device_ordinal) { string device_name = absl::StrCat(device.name, ":", device_ordinal); + string tf_device_name = absl::StrCat(device.tf_name, ":", device_ordinal); string xrt_device_name = absl::StrCat("/job:", worker_name, "/replica:0/task:", task_no, - "/device:", device_name); + "/device:", tf_device_name); options->device_map.emplace(device_name, xrt_device_name); } } diff --git a/third_party/xla_client/computation_client.h b/third_party/xla_client/computation_client.h index b0277b870627..2ff1b0fba37b 100644 --- a/third_party/xla_client/computation_client.h +++ b/third_party/xla_client/computation_client.h @@ -31,12 +31,16 @@ class ComputationClient { const Shape& shape() const { return shape_; } + virtual void Swap(Data* data) = 0; + private: int64 unique_id_ = 0; string device_; Shape shape_; }; + using DataPtr = std::shared_ptr; + class Computation { public: Computation(XlaComputation computation, ProgramShape program_shape, @@ -105,14 +109,18 @@ class ComputationClient { virtual ~ComputationClient() {} + // Creates a Data object with no actual device handle in it. The device handle + // will be populated in an asynchrounous fashion. + virtual DataPtr CreateDataPlaceholder(string device, Shape shape) = 0; + // Transfers local tensor values to the TPU servers and fetches the handles. - virtual std::vector> TransferToServer( + virtual std::vector TransferToServer( tensorflow::gtl::ArraySlice tensors) = 0; // Reads the tensor literal values stored at TPU server sites, behind the // supplied handles. virtual std::vector TransferFromServer( - tensorflow::gtl::ArraySlice> handles) = 0; + tensorflow::gtl::ArraySlice handles) = 0; // Compiles a set of computations. virtual std::vector> Compile( @@ -122,10 +130,10 @@ class ComputationClient { // The passed device must match the common device of the arguments Data. 
// If options.explode_tuple is true, the output tuple will be decomposed into // its single elements. - virtual std::vector> ExecuteComputation( + virtual std::vector ExecuteComputation( const Computation& computation, - tensorflow::gtl::ArraySlice arguments, const string& device, - const ExecuteComputationOptions& options) = 0; + tensorflow::gtl::ArraySlice arguments, + const string& device, const ExecuteComputationOptions& options) = 0; // Executes the computation in replicated mode. // The size of the arguments vector is the number of replicas to execute, @@ -138,9 +146,9 @@ class ComputationClient { // The result[i], a vector itself, will be the result of the computation fed // with arguments[i]. If options.explode_tuple is true, the output tuples will // be decomposed into their single elements. - virtual std::vector>> ExecuteReplicated( + virtual std::vector> ExecuteReplicated( const Computation& computation, - const std::vector>& arguments, + const std::vector>& arguments, tensorflow::gtl::ArraySlice devices, const ExecuteReplicatedOptions& options) = 0; @@ -151,14 +159,14 @@ class ComputationClient { // Returns a vector of vectors of device side Data object, with result[i] // being the return value of computations[i]. If options.explode_tuple is // true, the output tuples will be decomposed into their single elements. - virtual std::vector>> ExecuteParallel( + virtual std::vector> ExecuteParallel( tensorflow::gtl::ArraySlice computations, - const std::vector>& arguments, + const std::vector>& arguments, tensorflow::gtl::ArraySlice devices, const ExecuteParallelOptions& options) = 0; - virtual std::vector>> DeconstructTuple( - tensorflow::gtl::ArraySlice> tuples) = 0; + virtual std::vector> DeconstructTuple( + tensorflow::gtl::ArraySlice tuples) = 0; virtual string GetDefaultDevice() const = 0; diff --git a/third_party/xla_client/metrics.h b/third_party/xla_client/metrics.h index 3666cdb44d81..43e260c38e38 100644 --- a/third_party/xla_client/metrics.h +++ b/third_party/xla_client/metrics.h @@ -170,6 +170,11 @@ class TimedSection { int64 start_; }; +#define XLA_TIMED(name) \ + static xla::metrics::Metric* timed_metric = \ + new xla::metrics::Metric(name, xla::metrics::MetricFnTime); \ + xla::metrics::TimedSection timed_section(timed_metric) + } // namespace metrics } // namespace xla diff --git a/third_party/xla_client/multi_wait.cc b/third_party/xla_client/multi_wait.cc index 3f54eecc1340..3adbcd0814f0 100644 --- a/third_party/xla_client/multi_wait.cc +++ b/third_party/xla_client/multi_wait.cc @@ -36,7 +36,7 @@ void MultiWait::Reset(size_t count) { } std::function MultiWait::Completer(std::function func) { - auto completer = [this, func{std::move(func)}]() { + auto completer = [this, func = std::move(func)]() { try { func(); Done(); diff --git a/third_party/xla_client/util.h b/third_party/xla_client/util.h index 477b8d722781..66bed8e9b095 100644 --- a/third_party/xla_client/util.h +++ b/third_party/xla_client/util.h @@ -8,12 +8,43 @@ #include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" namespace xla { namespace util { +class Cleanup { + public: + explicit Cleanup(std::function func) : func_(std::move(func)) {} + Cleanup(Cleanup&& ref) : func_(std::move(ref.func_)) {} + Cleanup(const Cleanup&) = delete; + + ~Cleanup() { + if (func_ != nullptr) { + func_(std::move(status_)); + } + } + + Cleanup& operator=(const Cleanup&) = delete; + + Cleanup& 
operator=(Cleanup&& ref) { + if (this != &ref) { + func_ = std::move(ref.func_); + } + return *this; + } + + void Release() { func_ = nullptr; } + + void SetStatus(Status status) { status_ = std::move(status); } + + private: + std::function func_; + Status status_; +}; + // Allows APIs which might return const references and values, to not be forced // to return values in the signature. template diff --git a/third_party/xla_client/xrt_computation_client.cc b/third_party/xla_client/xrt_computation_client.cc index a52f60cafe92..6c0cb7947037 100644 --- a/third_party/xla_client/xrt_computation_client.cc +++ b/third_party/xla_client/xrt_computation_client.cc @@ -8,7 +8,6 @@ #include "absl/types/optional.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/multi_wait.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" #include "tensorflow/compiler/xla/xla_client/thread_pool.h" @@ -24,9 +23,22 @@ thread_local std::vector g_replication_devices; } // namespace +void XrtComputationClient::XrtData::Swap(Data* data) { + XrtData* xrt_data = dynamic_cast(data); + XLA_CHECK(xrt_data != nullptr); + if (xrt_data != this) { + // This requires no locking as this is managed by the caller, by properly + // locking device-wide operations. + std::swap(handle, xrt_data->handle); + std::swap(valid, xrt_data->valid); + } +} + XrtComputationClient::XrtComputationClient( XrtComputationClient::Options options) : options_(std::move(options)), + session_cache_([this](XrtSession* s) { InitSession(s); }), + alloc_session_cache_(nullptr), compilation_cache_(sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 64)), rng_seed_(0x5a2d296e9) { auto default_device_target = @@ -42,8 +54,12 @@ XrtComputationClient::XrtComputationClient( StartHandleReleaser(); } -std::vector> -XrtComputationClient::TransferToServer( +ComputationClient::DataPtr XrtComputationClient::CreateDataPlaceholder( + string device, Shape shape) { + return std::make_shared(this, std::move(device), std::move(shape)); +} + +std::vector XrtComputationClient::TransferToServer( tensorflow::gtl::ArraySlice tensors) { metrics::TimedSection timed(TransferToServerMetric()); @@ -65,7 +81,8 @@ XrtComputationClient::TransferToServer( { std::lock_guard slock(lock); - XrtSession* session = GetSessionForXrtDevice(xrt_device, &session_map); + XrtSession* session = GetSessionForXrtDevice(&alloc_session_cache_, + xrt_device, &session_map); SessionWork* session_work = &session_work_map[session]; tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); @@ -84,7 +101,7 @@ XrtComputationClient::TransferToServer( OutboundDataMetric()->AddSample(total_size); - std::vector> results(tensors.size()); + std::vector results(tensors.size()); for (auto& session_work : session_work_map) { std::vector outputs; XLA_CHECK_OK(session_work.first->session()->Run( @@ -104,20 +121,22 @@ XrtComputationClient::TransferToServer( } std::vector XrtComputationClient::TransferFromServer( - tensorflow::gtl::ArraySlice> handles) { + tensorflow::gtl::ArraySlice handles) { metrics::TimedSection timed(TransferFromServerMetric()); XrtSessionCache::SessionMap session_map; std::map session_work_map; for (size_t i = 0; i < handles.size(); ++i) { const XrtData& xrt_data = dynamic_cast(*handles[i]); - XrtSession* session = GetSessionForDevice(xrt_data.device(), &session_map); + XrtSession* session = + GetSessionForDevice(&session_cache_, 
xrt_data.device(), &session_map); SessionWork* session_work = &session_work_map[session]; tensorflow::Scope device_scope = session->root()->WithDevice(TorchDeviceToXrtDevice(xrt_data.device())); const XrtSession::CachedNode& cached_node = GetReadNode(session, device_scope, xrt_data.device()); - session_work->feed_inputs.insert({cached_node.holders[0], xrt_data.handle}); + session_work->feed_inputs.insert( + {cached_node.holders[0], xrt_data.get_handle()}); session_work->outputs_handles.push_back(cached_node.outputs[0]); session_work->index_mapping.push_back(i); } @@ -173,7 +192,7 @@ XrtComputationClient::Compile(std::vector instances) { { std::lock_guard slock(lock); XrtSession* session = - GetSessionForXrtDevice(xrt_device, &session_map); + GetSessionForXrtDevice(&session_cache_, xrt_device, &session_map); SessionWork* session_work = &session_work_map[session]; tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); @@ -185,7 +204,7 @@ XrtComputationClient::Compile(std::vector instances) { session_work->index_mapping.push_back(i); } } else { - results[i] = *computation_ptr; + results[i] = computation_ptr; } }; xla_env::ScheduleClosure(mwait.Completer(std::move(builder))); @@ -224,10 +243,10 @@ XrtComputationClient::Compile(std::vector instances) { return results; } -std::vector> +std::vector XrtComputationClient::ExecuteComputation( const Computation& computation, - tensorflow::gtl::ArraySlice arguments, const string& device, + tensorflow::gtl::ArraySlice arguments, const string& device, const ExecuteComputationOptions& options) { metrics::TimedSection timed(ExecuteMetric()); @@ -239,7 +258,8 @@ XrtComputationClient::ExecuteComputation( BuildParallelArguments(arguments), options.explode_tuple, {effective_device}, &feed_inputs); - XrtSession* session = GetSessionForDevice(effective_device, &session_map); + XrtSession* session = + GetSessionForDevice(&session_cache_, effective_device, &session_map); std::vector outputs; xrt_util::CheckComputationStatus( session->session()->Run(feed_inputs, {exec_ops.front()}, &outputs), @@ -250,10 +270,10 @@ XrtComputationClient::ExecuteComputation( effective_device); } -std::vector>> +std::vector> XrtComputationClient::ExecuteReplicated( const Computation& computation, - const std::vector>& arguments, + const std::vector>& arguments, tensorflow::gtl::ArraySlice devices, const ExecuteReplicatedOptions& options) { metrics::TimedSection timed(ExecuteReplicatedMetric()); @@ -270,7 +290,7 @@ XrtComputationClient::ExecuteReplicated( feed_inputs); } -std::vector>> +std::vector> XrtComputationClient::RunComputations( const XrtSessionCache::SessionMap& session_map, const std::vector& exec_ops, @@ -299,7 +319,7 @@ XrtComputationClient::RunComputations( XLA_CHECK_EQ(computations.size(), devices.size()); xla_util::MultiWait mwait(session_replicas.size()); - std::vector>> results(devices.size()); + std::vector> results(devices.size()); for (auto& sess_replica : session_replicas) { XrtSession* session = sess_replica.first; const std::vector& replicas = sess_replica.second; @@ -330,10 +350,10 @@ XrtComputationClient::RunComputations( return results; } -std::vector>> +std::vector> XrtComputationClient::ExecuteParallel( tensorflow::gtl::ArraySlice computations, - const std::vector>& arguments, + const std::vector>& arguments, tensorflow::gtl::ArraySlice devices, const ExecuteParallelOptions& options) { metrics::TimedSection timed(ExecuteParallelMetric()); @@ -347,9 +367,9 @@ XrtComputationClient::ExecuteParallel( feed_inputs); } -std::vector>> 
+std::vector> XrtComputationClient::DeconstructTuple( - tensorflow::gtl::ArraySlice> tuples) { + tensorflow::gtl::ArraySlice tuples) { metrics::TimedSection timed(DeconstructTupleMetric()); XrtSessionCache::SessionMap session_map; @@ -357,7 +377,8 @@ XrtComputationClient::DeconstructTuple( std::vector tuple_elements_count(tuples.size()); for (size_t i = 0; i < tuples.size(); ++i) { const XrtData& xrt_data = dynamic_cast(*tuples[i]); - XrtSession* session = GetSessionForDevice(xrt_data.device(), &session_map); + XrtSession* session = + GetSessionForDevice(&session_cache_, xrt_data.device(), &session_map); SessionWork* session_work = &session_work_map[session]; session_work->index_mapping.push_back(i); @@ -369,7 +390,7 @@ XrtComputationClient::DeconstructTuple( const XrtSession::CachedNode& cached_node = GetSubTupleNode(session, device_scope, xrt_data.device()); session_work->feed_inputs.insert( - {cached_node.holders[0], xrt_data.handle}); + {cached_node.holders[0], xrt_data.get_handle()}); tensorflow::Tensor index_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({1})); index_tensor.flat()(0) = j; @@ -378,7 +399,7 @@ XrtComputationClient::DeconstructTuple( } } - std::vector>> results(tuples.size()); + std::vector> results(tuples.size()); for (auto& session_work : session_work_map) { std::vector outputs; XLA_CHECK_OK(session_work.first->session()->Run( @@ -389,7 +410,7 @@ XrtComputationClient::DeconstructTuple( size_t output_index = 0; for (auto li : session_work.second.index_mapping) { const XrtData& xrt_data = dynamic_cast(*tuples[li]); - std::vector> tuple_results; + std::vector tuple_results; for (size_t i = 0; i < tuple_elements_count[li]; ++i, ++output_index) { tuple_results.push_back(std::make_shared( this, xrt_data.device(), @@ -404,19 +425,23 @@ XrtComputationClient::DeconstructTuple( } XrtSession* XrtComputationClient::GetSessionForTarget( - const string& target, XrtSessionCache::SessionMap* session_map) { - return session_cache_.GetSession(target, session_map); + XrtSessionCache* cache, const string& target, + XrtSessionCache::SessionMap* session_map) { + return cache->GetSession(target, session_map); } XrtSession* XrtComputationClient::GetSessionForXrtDevice( - const string& xrt_device, XrtSessionCache::SessionMap* session_map) { + XrtSessionCache* cache, const string& xrt_device, + XrtSessionCache::SessionMap* session_map) { auto worker_hostport = GetWorkerForXrtDevice(xrt_device); - return GetSessionForTarget(worker_hostport.second, session_map); + return GetSessionForTarget(cache, worker_hostport.second, session_map); } XrtSession* XrtComputationClient::GetSessionForDevice( - const string& device, XrtSessionCache::SessionMap* session_map) { - return GetSessionForXrtDevice(TorchDeviceToXrtDevice(device), session_map); + XrtSessionCache* cache, const string& device, + XrtSessionCache::SessionMap* session_map) { + return GetSessionForXrtDevice(cache, TorchDeviceToXrtDevice(device), + session_map); } string XrtComputationClient::GetEffectiveDevice(const string& device) const { @@ -484,14 +509,14 @@ std::unique_ptr XrtComputationClient::CreateXrtComputation( } tensorflow::Tensor XrtComputationClient::GetArgumentsInputs( - tensorflow::gtl::ArraySlice arguments, const string& device, + tensorflow::gtl::ArraySlice arguments, const string& device, tensorflow::ClientSession::FeedType* feed_inputs) { tensorflow::Tensor inputs_tensor(tensorflow::DT_INT64, tensorflow::TensorShape({arguments.size()})); for (size_t i = 0; i < arguments.size(); ++i) { - XrtData* xrt_data = 
dynamic_cast(arguments[i]); - XLA_CHECK_EQ(device, xrt_data->device()); - inputs_tensor.flat()(i) = xrt_data->handle; + const XrtData& xrt_data = dynamic_cast(*arguments[i]); + XLA_CHECK_EQ(device, xrt_data.device()); + inputs_tensor.flat()(i) = xrt_data.get_handle(); } return inputs_tensor; } @@ -499,7 +524,7 @@ tensorflow::Tensor XrtComputationClient::GetArgumentsInputs( std::vector XrtComputationClient::CreateExecuteOps( XrtSessionCache::SessionMap* session_map, tensorflow::gtl::ArraySlice computations, - const std::vector>& arguments, bool explode_tuple, + const std::vector>& arguments, bool explode_tuple, tensorflow::gtl::ArraySlice devices, tensorflow::ClientSession::FeedType* feed_inputs) { std::vector exec_ops; @@ -508,7 +533,8 @@ std::vector XrtComputationClient::CreateExecuteOps( dynamic_cast(computations[i]); auto inputs = GetArgumentsInputs(arguments[i], devices[i], feed_inputs); const string& xrt_device = TorchDeviceToXrtDevice(devices[i]); - XrtSession* session = GetSessionForXrtDevice(xrt_device, session_map); + XrtSession* session = + GetSessionForXrtDevice(&session_cache_, xrt_device, session_map); tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); const XrtSession::CachedNode& cached_node = GetExecuteNode(session, device_scope, devices[i]); @@ -531,18 +557,19 @@ std::vector XrtComputationClient::CreateExecuteOps( std::vector XrtComputationClient::CreateExecuteOps( XrtSessionCache::SessionMap* session_map, const XrtComputation& computation, - const std::vector>& arguments, bool explode_tuple, + const std::vector>& arguments, bool explode_tuple, tensorflow::gtl::ArraySlice devices, tensorflow::ClientSession::FeedType* feed_inputs) { std::vector exec_ops; for (size_t i = 0; i < arguments.size(); ++i) { auto inputs = GetArgumentsInputs(arguments[i], devices[i], feed_inputs); const string& xrt_device = TorchDeviceToXrtDevice(devices[i]); - XrtSession* session = GetSessionForXrtDevice(xrt_device, session_map); + XrtSession* session = + GetSessionForXrtDevice(&session_cache_, xrt_device, session_map); tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); const XrtSession::CachedNode& cached_node = GetExecuteNode(session, device_scope, devices[i]); - feed_inputs->insert({cached_node.holders[0], computation.handle}); + feed_inputs->insert({cached_node.holders[0], computation.get_handle()}); xrt::XRTExecutionConfig exec_config; exec_config.set_core_index_in_replica(0); @@ -575,7 +602,8 @@ void XrtComputationClient::ReleaseHandles( XrtSessionCache::SessionMap session_map; std::map> session_handles_map; for (auto& handle : released_handles) { - XrtSession* session = GetSessionForDevice(handle.device, &session_map); + XrtSession* session = + GetSessionForDevice(&session_cache_, handle.device, &session_map); session_handles_map[session].push_back(handle); } for (const auto& session_and_handles : session_handles_map) { @@ -638,7 +666,7 @@ bool XrtComputationClient::ReleaseHandle(XrtHandle* handle, std::lock_guard lock(lock_); absl::optional opt_handle = handle->Release(); if (opt_handle) { - handles->emplace_back(device, *opt_handle); + handles->push_back({device, *opt_handle}); released = true; } } @@ -709,8 +737,8 @@ tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology( "/replica:0/task:", worker_hostport.first.task_no, "/device:TPU_SYSTEM:0"); XrtSessionCache::SessionMap session_map; - XrtSession* session = - GetSessionForTarget(worker_hostport.second, &session_map); + XrtSession* session = GetSessionForTarget( + 
&session_cache_, worker_hostport.second, &session_map); tensorflow::Scope tpu_system_scope = session->root()->WithDevice(system_device); const auto unique_name = @@ -781,11 +809,11 @@ void XrtComputationClient::InitializeDevices() { } } -std::vector> +std::vector XrtComputationClient::GetComputationResults( const tensorflow::Tensor& xrt_result, const Shape& result_shape, const string& device) { - std::vector> results; + std::vector results; if (xrt_result.dims() == 1) { auto handles_vec = xrt_result.vec(); for (int64 i = 0; i < handles_vec.size(); ++i) { @@ -827,6 +855,39 @@ const std::vector& XrtComputationClient::GetReplicationDevices() const { void XrtComputationClient::SetRngSeed(size_t seed) { rng_seed_ = seed; } +void XrtComputationClient::InitSession(XrtSession* session) const { + struct InitNode { + int count; + const XrtSession::CachedNode& (XrtComputationClient::*node_ctor)( + XrtSession*, const tensorflow::Scope&, const string&)const; + } const init_nodes[] = { + {16, &XrtComputationClient::GetCompileNode}, + {16, &XrtComputationClient::GetExecuteNode}, + {16, &XrtComputationClient::GetReadNode}, + {16, &XrtComputationClient::GetReleaseAllocationHandleNode}, + {16, &XrtComputationClient::GetReleaseCompileHandleNode}, + {16, &XrtComputationClient::GetSubTupleNode}, + }; + auto devices = GetAvailableDevices(); + for (auto& device : devices) { + // HACK: The XRT ops on the remote GRPC service has only recently been + // enabled, so until TF 1.14 is out, we cannot add XRT ops on CPU. + // If there is only one device, even if CPU, this is the local session, + // which carries the XRT op (as we include them in the BUILD). + if (device.compare(0, 4, "CPU:") == 0 && devices.size() > 1) { + continue; + } + const string& xrt_device = TorchDeviceToXrtDevice(device); + tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device); + for (auto& init : init_nodes) { + for (int i = 0; i < init.count; ++i) { + (this->*init.node_ctor)(session, device_scope, device); + } + } + } + session->Reset(); +} + const XrtSession::CachedNode& XrtComputationClient::GetCompileNode( XrtSession* session, const tensorflow::Scope& scope, const string& device) const { @@ -834,6 +895,7 @@ const XrtSession::CachedNode& XrtComputationClient::GetCompileNode( XrtSession::NodeCache* cache = session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); if (cache->Empty()) { + XLA_COUNTER("XrtCompile_Empty", 1); std::vector holders( {tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING)}); cache->Add(std::make_shared( @@ -850,6 +912,7 @@ const XrtSession::CachedNode& XrtComputationClient::GetExecuteNode( XrtSession::NodeCache* cache = session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); if (cache->Empty()) { + XLA_COUNTER("XrtExecute_Empty", 1); std::vector holders( {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64), tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), @@ -871,6 +934,7 @@ const XrtSession::CachedNode& XrtComputationClient::GetReadNode( XrtSession::NodeCache* cache = session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); if (cache->Empty()) { + XLA_COUNTER("XrtRead_Empty", 1); std::vector holders( {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64)}); cache->Add(std::make_shared( @@ -890,6 +954,7 @@ const XrtSession::CachedNode& XrtComputationClient::GetAllocateNode( XrtSession::NodeCache* cache = session->GetNodeCache(XrtSession::GetCacheKey(ss.str(), device)); if (cache->Empty()) { + XLA_COUNTER("XRTAllocateFromTensor_Empty", 1); 
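// Node-cache miss: build the XRTAllocateFromTensor feed placeholders and op
// once for this shape/device key, and cache them so later host-to-device
// uploads of the same shape reuse the already-created graph nodes.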
tensorflow::TensorShape tensor_shape(shape.dimensions()); tensorflow::TensorShape equiv_tensor_shape = MakeEquivalentTensorShape(shape); @@ -917,6 +982,7 @@ XrtComputationClient::GetReleaseAllocationHandleNode( XrtSession::NodeCache* cache = session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); if (cache->Empty()) { + XLA_COUNTER("XrtReleaseAllocationHandle_Empty", 1); std::vector holders( {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64)}); cache->Add(std::make_shared( @@ -933,6 +999,7 @@ const XrtSession::CachedNode& XrtComputationClient::GetReleaseCompileHandleNode( XrtSession::NodeCache* cache = session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); if (cache->Empty()) { + XLA_COUNTER("XrtReleaseCompileHandle_Empty", 1); std::vector holders( {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64)}); cache->Add(std::make_shared( @@ -949,6 +1016,7 @@ const XrtSession::CachedNode& XrtComputationClient::GetSubTupleNode( XrtSession::NodeCache* cache = session->GetNodeCache(XrtSession::GetCacheKey(op_name, device)); if (cache->Empty()) { + XLA_COUNTER("XrtSubTuple_Empty", 1); std::vector holders( {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64), tensorflow::ops::Placeholder( @@ -1000,10 +1068,10 @@ tensorflow::TensorShape XrtComputationClient::MakeEquivalentTensorShape( return tensorflow::TensorShape(eqiv_shape.dimensions()); } -std::vector> +std::vector> XrtComputationClient::BuildParallelArguments( - tensorflow::gtl::ArraySlice arguments) { - std::vector> para_arguments(1); + tensorflow::gtl::ArraySlice arguments) { + std::vector> para_arguments(1); para_arguments[0].insert(para_arguments[0].end(), arguments.begin(), arguments.end()); return para_arguments; diff --git a/third_party/xla_client/xrt_computation_client.h b/third_party/xla_client/xrt_computation_client.h index 4283bf77d85f..d75a6963ff7f 100644 --- a/third_party/xla_client/xrt_computation_client.h +++ b/third_party/xla_client/xrt_computation_client.h @@ -16,6 +16,7 @@ #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/xla/xla_client/cache.h" #include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/metrics.h" #include "tensorflow/compiler/xla/xla_client/triggered_task.h" #include "tensorflow/compiler/xla/xla_client/util.h" @@ -33,40 +34,49 @@ namespace xla { class XrtComputationClient : public ComputationClient { struct DeviceHandle { - DeviceHandle(string device, int64 handle) - : device(std::move(device)), handle(handle) {} - string device; int64 handle; }; struct XrtHandle { + explicit XrtHandle(XrtComputationClient* self) + : self(self), handle(0), valid(false) {} XrtHandle(XrtComputationClient* self, int64 handle) - : self(self), handle(handle), released(false) {} + : self(self), handle(handle), valid(true) {} absl::optional Release() { - if (released.exchange(true)) { + if (!valid) { return absl::nullopt; } + valid = false; + return handle; + } + + int64 get_handle() const { + XLA_CHECK(valid) << "Invalid handle: " << handle; return handle; } XrtComputationClient* self; int64 handle; - std::atomic released; + bool valid; }; struct XrtData : public Data, public XrtHandle { + XrtData(XrtComputationClient* self, string device, Shape device_shape) + : Data(std::move(device), std::move(device_shape)), XrtHandle(self) {} XrtData(XrtComputationClient* self, string device, Shape device_shape, int64 handle) : Data(std::move(device), 
std::move(device_shape)), XrtHandle(self, handle) {} ~XrtData() override { - if (!released) { + if (valid) { self->ReleaseXrtData(this); } } + + void Swap(Data* data) override; }; struct XrtComputation : public Computation, public XrtHandle { @@ -79,7 +89,7 @@ class XrtComputationClient : public ComputationClient { compilation_device(std::move(compilation_device)) {} ~XrtComputation() override { - if (!released) { + if (valid) { self->ReleaseXrtComputation(this); } } @@ -116,35 +126,36 @@ class XrtComputationClient : public ComputationClient { XrtComputationClient(Options options); - std::vector> TransferToServer( + DataPtr CreateDataPlaceholder(string device, Shape shape) override; + + std::vector TransferToServer( tensorflow::gtl::ArraySlice tensors) override; std::vector TransferFromServer( - tensorflow::gtl::ArraySlice> handles) - override; + tensorflow::gtl::ArraySlice handles) override; std::vector> Compile( std::vector instances) override; - std::vector> ExecuteComputation( + std::vector ExecuteComputation( const Computation& computation, - tensorflow::gtl::ArraySlice arguments, const string& device, - const ExecuteComputationOptions& options) override; + tensorflow::gtl::ArraySlice arguments, + const string& device, const ExecuteComputationOptions& options) override; - std::vector>> ExecuteReplicated( + std::vector> ExecuteReplicated( const Computation& computation, - const std::vector>& arguments, + const std::vector>& arguments, tensorflow::gtl::ArraySlice devices, const ExecuteReplicatedOptions& options) override; - std::vector>> ExecuteParallel( + std::vector> ExecuteParallel( tensorflow::gtl::ArraySlice computations, - const std::vector>& arguments, + const std::vector>& arguments, tensorflow::gtl::ArraySlice devices, const ExecuteParallelOptions& options) override; - std::vector>> DeconstructTuple( - tensorflow::gtl::ArraySlice> tuples) override; + std::vector> DeconstructTuple( + tensorflow::gtl::ArraySlice tuples) override; string GetDefaultDevice() const override; @@ -168,11 +179,12 @@ class XrtComputationClient : public ComputationClient { std::vector index_mapping; }; - XrtSession* GetSessionForTarget(const string& target, + XrtSession* GetSessionForTarget(XrtSessionCache* cache, const string& target, XrtSessionCache::SessionMap* session_map); - XrtSession* GetSessionForXrtDevice(const string& xrt_device, + XrtSession* GetSessionForXrtDevice(XrtSessionCache* cache, + const string& xrt_device, XrtSessionCache::SessionMap* session_map); - XrtSession* GetSessionForDevice(const string& device, + XrtSession* GetSessionForDevice(XrtSessionCache* cache, const string& device, XrtSessionCache::SessionMap* session_map); string GetEffectiveDevice(const string& device) const; @@ -188,24 +200,24 @@ class XrtComputationClient : public ComputationClient { const Shape* output_shape) const; tensorflow::Tensor GetArgumentsInputs( - tensorflow::gtl::ArraySlice arguments, const string& device, - tensorflow::ClientSession::FeedType* feed_inputs); + tensorflow::gtl::ArraySlice arguments, + const string& device, tensorflow::ClientSession::FeedType* feed_inputs); std::vector CreateExecuteOps( XrtSessionCache::SessionMap* session_map, tensorflow::gtl::ArraySlice computations, - const std::vector>& arguments, bool explode_tuple, + const std::vector>& arguments, bool explode_tuple, tensorflow::gtl::ArraySlice devices, tensorflow::ClientSession::FeedType* feed_inputs); std::vector CreateExecuteOps( XrtSessionCache::SessionMap* session_map, const XrtComputation& computation, - const std::vector>& 
arguments, bool explode_tuple, + const std::vector>& arguments, bool explode_tuple, tensorflow::gtl::ArraySlice devices, tensorflow::ClientSession::FeedType* feed_inputs); - std::vector>> RunComputations( + std::vector> RunComputations( const XrtSessionCache::SessionMap& session_map, const std::vector& exec_ops, tensorflow::gtl::ArraySlice computations, @@ -214,7 +226,7 @@ class XrtComputationClient : public ComputationClient { // Retrieves the worker,worker_host pair for a given PyTorch device (ie, // TPU:0). - std::pair GetWorkerForDevice(const string& xrt_device) const; + std::pair GetWorkerForDevice(const string& device) const; // Retrieves the worker,worker_host pair for a given XRT device (ie, // /job:tpu_worker/replica:0/task:0/device:TPU:0). @@ -249,10 +261,12 @@ class XrtComputationClient : public ComputationClient { void InitializeDevices(); - std::vector> GetComputationResults( + std::vector GetComputationResults( const tensorflow::Tensor& xrt_result, const Shape& result_shape, const string& device); + void InitSession(XrtSession* session) const; + // Creates an XRT graph with an XRTCompile operation: // // XRTCompile( @@ -352,8 +366,8 @@ class XrtComputationClient : public ComputationClient { // Builds an argument vector usable in a replicated context, out of a single // replica argument vector. Essentially turns a [N] into a [1][N]. - static std::vector> BuildParallelArguments( - tensorflow::gtl::ArraySlice arguments); + static std::vector> BuildParallelArguments( + tensorflow::gtl::ArraySlice arguments); // Extracts the XlaComputation pointers out of Computation ones. Used to be // passed to xrt_util::CheckComputationStatus() for its error reporting. @@ -368,9 +382,9 @@ class XrtComputationClient : public ComputationClient { std::mutex lock_; std::map> device_mesh_coords_; XrtSessionCache session_cache_; + XrtSessionCache alloc_session_cache_; std::unique_ptr triggered_task_; - util::Cache, - util::PartialHasher> + util::Cache> compilation_cache_; std::atomic rng_seed_; // Access to the following members must be done while holding lock_. 
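Note on the header above: it introduces two session caches, session_cache_, whose sessions get pre-populated by InitSession() with the cached XRT nodes, and alloc_session_cache_, constructed with a null init function and used only by TransferToServer(). Keeping allocations on their own sessions is what avoids the "one more node in the graph" TF case called out in the commit message: executions normally run against sessions whose graphs were already built up front. The next file wires the init callback into XrtSessionCache. Below is a minimal standalone sketch of that init-callback pattern; Session, SessionCache and the node-name strings are illustrative stand-ins, not the patch's actual types.

// Sketch: a per-target session cache whose init callback does the expensive
// graph building once, when the session is first created.
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

struct Session {
  std::string target;
  std::vector<std::string> nodes;  // stands in for XrtSession's node caches
};

class SessionCache {
 public:
  explicit SessionCache(std::function<void(Session*)> initfn)
      : initfn_(std::move(initfn)) {}

  std::shared_ptr<Session> GetSession(const std::string& target) {
    std::lock_guard<std::mutex> lock(lock_);
    std::shared_ptr<Session>& session = sessions_[target];
    if (session == nullptr) {
      session = std::make_shared<Session>();
      session->target = target;
      if (initfn_ != nullptr) {
        initfn_(session.get());  // build every node the session should need
      }
    }
    return session;
  }

 private:
  std::function<void(Session*)> initfn_;
  std::mutex lock_;
  std::map<std::string, std::shared_ptr<Session>> sessions_;
};

// Usage mirroring the patch: execution sessions get warmed, allocation
// sessions do not.
//   SessionCache exec_sessions(
//       [](Session* s) { s->nodes = {"XRTCompile", "XRTExecute"}; });
//   SessionCache alloc_sessions(nullptr);
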
diff --git a/third_party/xla_client/xrt_session_cache.cc b/third_party/xla_client/xrt_session_cache.cc index 3bf99e95a72d..b8a70cd1f2ba 100644 --- a/third_party/xla_client/xrt_session_cache.cc +++ b/third_party/xla_client/xrt_session_cache.cc @@ -1,9 +1,13 @@ #include "tensorflow/compiler/xla/xla_client/xrt_session_cache.h" +#include "tensorflow/compiler/xla/xla_client/metrics.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" namespace xla { +XrtSessionCache::XrtSessionCache(std::function initfn) + : initfn_(std::move(initfn)) {} + XrtSessionCache::Ref XrtSessionCache::GetSession(const string& target) { std::lock_guard lock(lock_); auto& session_queue = session_map_[target]; @@ -32,6 +36,7 @@ void XrtSessionCache::AddSession(std::shared_ptr session) { std::shared_ptr XrtSessionCache::CreateSession( const string& target) const { + XLA_COUNTER("XrtSessionCount", 1); tensorflow::SessionOptions session_options; session_options.env = tensorflow::Env::Default(); session_options.target = target; @@ -44,7 +49,12 @@ std::shared_ptr XrtSessionCache::CreateSession( rpc_options->set_compression_level( sys_util::GetEnvInt("XRT_GRPC_COMPRESSION_LEVEL", 3)); } - return std::make_shared(session_options); + std::shared_ptr session = + std::make_shared(session_options); + if (initfn_ != nullptr) { + initfn_(session.get()); + } + return session; } } // namespace xla diff --git a/third_party/xla_client/xrt_session_cache.h b/third_party/xla_client/xrt_session_cache.h index e081c232c3e6..7872830876a1 100644 --- a/third_party/xla_client/xrt_session_cache.h +++ b/third_party/xla_client/xrt_session_cache.h @@ -65,6 +65,8 @@ class XrtSessionCache { // Map from session target to XrtSession reference. using SessionMap = std::map; + explicit XrtSessionCache(std::function initfn); + // Retrieves a new session reference, for which the caller will have exclusive // access. Once the reference object is destroyed, the session will be // returned to the cache. @@ -80,6 +82,7 @@ class XrtSessionCache { private: std::shared_ptr CreateSession(const string& target) const; + std::function initfn_; std::mutex lock_; std::map>> session_map_; }; diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index ec7b292bb683..69be6938a12b 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -21,15 +21,6 @@ xla::XlaOp LoweringContext::GetParameter( return it->second; } -std::vector LoweringContext::GetParametersData() - const { - std::vector parameters; - for (auto& param : parameters_) { - parameters.push_back(param.get()); - } - return parameters; -} - xla::int64 LoweringContext::AddResult(xla::XlaOp op) { root_tuple_.push_back(std::move(op)); return root_tuple_.size() - 1; diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index c67ec6460305..153d87d278cb 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -31,7 +31,10 @@ class LoweringContext { // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. - std::vector GetParametersData() const; + const std::vector& GetParametersData() + const { + return parameters_; + } // Adds the output of a given operation to the result tuple. 
xla::int64 AddResult(xla::XlaOp op); @@ -65,7 +68,7 @@ class LoweringContext { const char* error_msg); xla::XlaBuilder builder_; - std::vector> parameters_; + std::vector parameters_; std::unordered_map parameters_map_; std::vector root_tuple_; OutputMap emitted_outputs_; diff --git a/torch_xla/csrc/module.cpp b/torch_xla/csrc/module.cpp index 358c260771e6..ffc370c7f539 100644 --- a/torch_xla/csrc/module.cpp +++ b/torch_xla/csrc/module.cpp @@ -60,8 +60,7 @@ void GatherParameters(std::vector* values, } XlaModule::TensorBatchVector CreateResultBatchVector( - std::vector>> - results) { + std::vector> results) { XlaModule::TensorBatchVector batch_tensors; for (auto& replica_result_components : results) { XlaModule::TensorBatchVector::value_type replica_tensors; @@ -600,8 +599,7 @@ std::vector XlaModule::GetStringDevices() const { XlaModule::TensorBatchVector XlaModule::Execute( const xla::ComputationClient::Computation& computation, const DataBatchVector& inputs) { - std::vector>> - exec_results; + std::vector> exec_results; if (inputs.size() == 1) { xla::ComputationClient::ExecuteComputationOptions options; exec_results.push_back(xla::ComputationClient::Get()->ExecuteComputation( @@ -684,7 +682,7 @@ XlaModule::DataBatchVector XlaModule::GetDataBatchVector( DataBatchVector::value_type replica_inputs_data; for (size_t j = 0; j < replica_inputs.size(); ++j) { if (zero_input == nullptr || !zero_input->at(j)) { - replica_inputs_data.push_back(replica_inputs[j].GetXlaData().get()); + replica_inputs_data.push_back(replica_inputs[j].GetXlaData()); } } inputs_data.push_back(std::move(replica_inputs_data)); diff --git a/torch_xla/csrc/module.h b/torch_xla/csrc/module.h index 73ccc357e2fd..8cafcaf3962e 100644 --- a/torch_xla/csrc/module.h +++ b/torch_xla/csrc/module.h @@ -51,7 +51,7 @@ struct XlaModule : public std::enable_shared_from_this { // The i-th entry in this vector, is a vector of XLA computation data which // belong the i-th replica. using DataBatchVector = - std::vector>; + std::vector>; void Initialize(const TensorBatchVector& inputs); diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index be2355200ab2..b9baa778456e 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -2,8 +2,11 @@ #include #include +#include #include #include +#include +#include #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/xla_client/cache.h" @@ -12,8 +15,8 @@ #include "tensorflow/compiler/xla/xla_client/multi_wait.h" #include "tensorflow/compiler/xla/xla_client/thread_pool.h" #include "tensorflow/compiler/xla/xla_client/unique.h" -#include "tensorflow/compiler/xla/xla_client/util.h" #include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "tensorflow/core/lib/core/errors.h" #include "torch/csrc/autograd/variable.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir_util.h" @@ -26,29 +29,119 @@ namespace torch_xla { namespace { -struct CachedComputation { - std::shared_ptr computation; - size_t num_parameters; +// Locking: +// We perform two kinds of operations of tensors, synchronous and asynchronous. +// The ApplyPendingGraph() are synchronous, as we need the device data result +// immediately. Before the synchronous operations can start, they need to wait +// that the pending asynchronous operations have completed. +// Synchronous operations do not hold device locks, since they are strictly +// sequential, dictated by the PyTorch execution order. 
+// The SyncTensorsGraph() is asynchronous, and returns immediately after having +// scheduled the asynchronous operation. While executing, the asynchronous +// operations will hold locks on all the participating devices (in most common +// cases there will be only one device). +// Since asynchronous operations capture device locks, only one asynchronous +// operation can execute at the same time, on a given device. Tensor operations +// which send data to device do not need to hold any device locks while doing +// so. Only operations which _use_ device data (computations, and transfer from +// server) need to wait for asynchronous operations to complete (barrier). + +class DeviceLocker { + public: + explicit DeviceLocker(Device device) : device_(std::move(device)) {} + + const Device& device() const { return device_; } + + void Lock() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return !locked_; }); + CheckResetStatus(); + locked_ = true; + } + + void Unlock(xla::Status status) { + std::lock_guard lock(mutex_); + locked_ = false; + status_ = std::move(status); + cv_.notify_one(); + } + + void Barrier() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return !locked_; }); + CheckResetStatus(); + } + + private: + void CheckResetStatus() { + xla::Status status = std::move(status_); + status_ = xla::Status::OK(); + if (!status.ok()) { + throw std::runtime_error(status.error_message()); + } + } + + Device device_; + std::mutex mutex_; + std::condition_variable cv_; + bool locked_ = false; + xla::Status status_; +}; + +class DeviceLockerArena { + public: + static DeviceLockerArena* Get() { + static DeviceLockerArena* arena = new DeviceLockerArena(); + return arena; + } + + std::shared_ptr GetLocker(const Device& device) { + std::lock_guard lock(mutex_); + auto it = lockers_.find(device); + if (it == lockers_.end()) { + it = lockers_.emplace(device, std::make_shared(device)) + .first; + } + return it->second; + } + + private: + std::mutex mutex_; + std::map> lockers_; }; -using ComputationCache = xla::util::Cache; +xla::util::Cleanup LockDevice(const Device& device) { + auto locker = DeviceLockerArena::Get()->GetLocker(device); + locker->Lock(); + return xla::util::Cleanup([locker = std::move(locker)](xla::Status status) { + locker->Unlock(std::move(status)); + }); +} + +void DeviceBarrier(const Device& device) { + auto locker = DeviceLockerArena::Get()->GetLocker(device); + locker->Barrier(); +} -ComputationCache* GetComputationCache() { - static const size_t kMaxCacheSize = 128; - static ComputationCache* cache = new ComputationCache(kMaxCacheSize); - return cache; +// Use a set to impose an order on the device locking sequence (ABBA +// prevention). +std::vector LockDevices(const std::set& devices) { + std::vector unlocker; + unlocker.reserve(devices.size()); + for (auto& device : devices) { + unlocker.emplace_back(LockDevice(device)); + } + return unlocker; } void SetMulti(std::vector* dest_tuple, - std::vector> - new_dest_elements, + std::vector new_dest_elements, const std::vector& index_mapping) { XLA_CHECK_EQ(index_mapping.size(), new_dest_elements.size()); // Replace the underlying data for the destination tensors with the data in // "new_dest_elements". for (size_t i = 0; i < new_dest_elements.size(); ++i) { size_t dest_tuple_index = index_mapping[i]; - // Prefer not to make SetXlaData() non-const. 
(*dest_tuple)[dest_tuple_index].SetXlaData(std::move(new_dest_elements[i])); } } @@ -109,7 +202,7 @@ XLATensor XLATensor::Create(const at::Tensor& tensor, const Device& device, } XLATensor XLATensor::Create( - std::shared_ptr xla_data, bool requires_grad, + xla::ComputationClient::DataPtr xla_data, bool requires_grad, c10::optional logical_element_type) { XLATensor xtensor(std::move(xla_data), requires_grad, logical_element_type); TensorsArena::Get()->RegisterTensor(xtensor.data_ptr()); @@ -138,7 +231,7 @@ XLATensor::XLATensor(const at::Tensor& tensor, const Device& device, data()->requires_grad = requires_grad; } -XLATensor::XLATensor(std::shared_ptr xla_data, +XLATensor::XLATensor(xla::ComputationClient::DataPtr xla_data, bool requires_grad, c10::optional logical_element_type) : data_(std::make_shared(xla_data, Device(xla_data->device()), @@ -217,7 +310,7 @@ const Device& XLATensor::GetDevice() const { return data()->device; } xla::int64 XLATensor::GetUniqueId() const { return data()->unique_id; } -std::shared_ptr XLATensor::GetXlaData() { +xla::ComputationClient::DataPtr XLATensor::GetXlaData() { bool up_to_date = true; if (data()->view != nullptr) { View::IrNode ir_value_updated = data()->view->GetViewIrNode(); @@ -227,7 +320,7 @@ std::shared_ptr XLATensor::GetXlaData() { } } if (up_to_date) { - std::shared_ptr xla_data = CurrentXlaData(); + xla::ComputationClient::DataPtr xla_data = CurrentXlaData(); if (xla_data != nullptr) { return xla_data; } @@ -241,8 +334,7 @@ std::shared_ptr XLATensor::GetXlaData() { return data()->xla_data; } -std::shared_ptr XLATensor::CurrentXlaData() - const { +xla::ComputationClient::DataPtr XLATensor::CurrentXlaData() const { if (data()->xla_data != nullptr) { // When we set a new Node for a tensor, we leave the XLA data pointer alive, // as it is needed in order for the cached tensor apply operation to work. @@ -277,8 +369,7 @@ std::string XLATensor::DumpGraphNodeComputation() const { return hlo_text; } -void XLATensor::SetXlaData( - std::shared_ptr xla_data) { +void XLATensor::SetXlaData(xla::ComputationClient::DataPtr xla_data) { data()->view = nullptr; data()->xla_data = std::move(xla_data); data()->ir_value = ir::Value(); @@ -340,7 +431,7 @@ ir::Value XLATensor::GetIrValue() const { if (ir_value) { return ir_value; } - std::shared_ptr xla_data = CurrentXlaData(); + xla::ComputationClient::DataPtr xla_data = CurrentXlaData(); if (xla_data != nullptr) { // In case of tensor node, we do not clear the XLA data when we set the IR // node. This because we want further calls to GetIrValue() to fetch the @@ -461,7 +552,7 @@ std::vector XLATensor::GetTensors( // support getting handles and data might save a few pennies here. 
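// Note: ApplyPendingGraph() is synchronous; per the locking notes at the top
// of this file it first waits for any in-flight asynchronous sync on the
// participating devices before device data is read back.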
ApplyPendingGraph(tensors, /*apply_context=*/nullptr); - std::vector> tensors_data; + std::vector tensors_data; for (auto& tensor : *tensors) { if (!tensor.CurrentTensorData()) { tensors_data.push_back(tensor.GetXlaData()); @@ -503,7 +594,7 @@ std::vector XLATensor::GetTensors( std::vector XLATensor::CreateTensors( const std::vector& tensors, const std::vector& devices) { - std::vector> handles = + std::vector handles = CreateTensorsData(tensors, devices); std::vector xla_tensors; for (size_t i = 0; i < handles.size(); ++i) { @@ -514,8 +605,7 @@ std::vector XLATensor::CreateTensors( return xla_tensors; } -ir::Value XLATensor::CreateTensorNode( - std::shared_ptr data) { +ir::Value XLATensor::CreateTensorNode(xla::ComputationClient::DataPtr data) { return ir::ops::DeviceDataOp(std::move(data)); } @@ -563,6 +653,7 @@ XLATensor XLATensor::CreateFrom(ir::Value ir_value, const Device& device, } void XLATensor::ApplyPendingGraph() { + DeviceBarrier(GetDevice()); // This method is called to ensure that the tensor data is available on // device, so that a call to CurrentXlaData() returns a valid pointer. if (CurrentXlaData() == nullptr) { @@ -600,11 +691,17 @@ void XLATensor::ApplyPendingGraph() { XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors( const std::vector& tensors) { + std::set device_set; + for (size_t i = 0; i < tensors.size(); ++i) { + device_set.insert(tensors[i].GetDevice()); + } + std::vector at_tensors; std::vector devices; std::vector at_tensor_index; SyncTensorCollection coll; coll.indices.reserve(tensors.size()); + coll.unlocker = LockDevices(device_set); for (size_t i = 0; i < tensors.size(); ++i) { if (tensors[i].CurrentXlaData() == nullptr) { ir::Value ir_value = tensors[i].CurrentIrValue(); @@ -624,7 +721,8 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors( } } if (!at_tensors.empty()) { - std::vector> handles = + XLA_COUNTER("SyncTensorsToData", at_tensors.size()); + std::vector handles = CreateTensorsData(at_tensors, devices); for (size_t i = 0; i < handles.size(); ++i) { // If we are here, it means that the IR Value for the tensor is not @@ -638,45 +736,117 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors( } bool XLATensor::TryRunCachedSync(std::vector* tensors, - const SyncTensorCollection& coll) { - const CachedComputation* cached_computation = - GetComputationCache()->Get(coll.hash); + SyncTensorCollection* coll) { + ComputationCache::TypePtr cached_computation = + GetComputationCache()->Get(coll->hash); if (cached_computation == nullptr) { return false; } xla::xla_util::Unique unique_device; std::vector roots; - roots.reserve(coll.indices.size()); - for (auto index : coll.indices) { + roots.reserve(coll->indices.size()); + for (auto index : coll->indices) { ir::Value ir_value = (*tensors)[index].CurrentIrValue(); roots.push_back(ir_value.node.get()); unique_device.set((*tensors)[index].GetDevice()); } - std::vector parameters_data; + std::vector parameters_data; for (auto node : ir::Util::ComputePostOrder(roots)) { const ir::ops::DeviceData* device_data = dynamic_cast(node); if (device_data != nullptr) { - parameters_data.push_back(device_data->data().get()); + parameters_data.push_back(device_data->data()); } } if (cached_computation->num_parameters != parameters_data.size()) { return false; } - - xla::ComputationClient::ExecuteComputationOptions options; - auto result = xla::ComputationClient::Get()->ExecuteComputation( - *cached_computation->computation, parameters_data, - unique_device->ToString(), options); - 
SetMulti(tensors, std::move(result), coll.indices); XLA_COUNTER("CachedSyncTensors", 1); + + ScheduleSyncTensorsGraph(tensors, coll, std::move(parameters_data), + unique_device->ToString(), + std::move(cached_computation)); return true; } +XLATensor::ComputationCache* XLATensor::GetComputationCache() { + static const size_t kMaxCacheSize = 128; + static ComputationCache* cache = new ComputationCache(kMaxCacheSize); + return cache; +} + +void XLATensor::ScheduleSyncTensorsGraph( + std::vector* tensors, SyncTensorCollection* coll, + std::vector parameters_data, + std::string device, ComputationCache::TypePtr cached_computation) { + // The xla::util::Cleanup is not copyable, so even though we can create a + // lambda by moving it inside the capture, channeling through a + // std::function<> requires captures to be copyable (even though nothing would + // actually be copied). + struct Async { + Async(SyncTensorCollection* coll, + std::vector parameters_data, + std::string device, ComputationCache::TypePtr cached_computation) + : unlocker(std::move(coll->unlocker)), + parameters_data(std::move(parameters_data)), + device(std::move(device)), + cached_computation(std::move(cached_computation)) { + tensors_data.reserve(coll->indices.size()); + } + + std::vector unlocker; + std::vector parameters_data; + std::string device; + ComputationCache::TypePtr cached_computation; + std::vector tensors_data; + }; + std::shared_ptr async = std::make_shared( + coll, std::move(parameters_data), device, std::move(cached_computation)); + for (auto index : coll->indices) { + // The purpose of a tensor sync operation is to truncate the IR graph and + // materialize device data in place of the IR graph, on selected tensors. + // But since the operation will complete asynchronously, if a tensor does not + // already have device data, we need to install a placeholder. Since at this + // point we hold a lock on the device where the tensors reside (locks held + // within the coll structure, and moved into the async variable), any other + // operation trying to access the tensor's device data will have to wait + // until the asynchronous operation completes. 
+ xla::ComputationClient::DataPtr xla_data = + (*tensors)[index].CurrentXlaData(); + if (xla_data == nullptr) { + xla_data = xla::ComputationClient::Get()->CreateDataPlaceholder( + device, (*tensors)[index].shape()); + (*tensors)[index].SetXlaData(xla_data); + } + async->tensors_data.emplace_back(std::move(xla_data)); + } + + auto syncfn = [async = std::move(async)]() { + xla::ComputationClient::ExecuteComputationOptions options; + try { + auto results = xla::ComputationClient::Get()->ExecuteComputation( + *async->cached_computation->computation, async->parameters_data, + async->device, options); + for (size_t i = 0; i < results.size(); ++i) { + async->tensors_data[i]->Swap(results[i].get()); + } + } catch (const std::exception& ex) { + xla::Status status = tensorflow::errors::Aborted(ex.what()); + for (auto& unlocker : async->unlocker) { + unlocker.SetStatus(status); + } + } + }; + + xla::xla_env::ScheduleIoClosure(std::move(syncfn)); +} + void XLATensor::SyncTensorsGraph(std::vector* tensors) { SyncTensorCollection coll = CollectSyncTensors(*tensors); - if (!coll.indices.empty() && !TryRunCachedSync(tensors, coll)) { + if (!coll.indices.empty() && !TryRunCachedSync(tensors, &coll)) { + XLA_COUNTER("UncachedSyncTensors", 1); + xla::xla_util::Unique unique_device; ir::LoweringContext lowering_ctx("SyncTensorsGraph"); for (auto index : coll.indices) { @@ -700,24 +870,25 @@ void XLATensor::SyncTensorsGraph(std::vector* tensors) { std::vector> computations = xla::ComputationClient::Get()->Compile(std::move(instances)); - - std::vector parameters_data = + std::vector parameters_data = lowering_ctx.GetParametersData(); - xla::ComputationClient::ExecuteComputationOptions options; - auto result = xla::ComputationClient::Get()->ExecuteComputation( - *computations.front(), parameters_data, unique_device->ToString(), - options); - SetMulti(tensors, std::move(result), coll.indices); - - GetComputationCache()->Add( - coll.hash, {std::move(computations.front()), parameters_data.size()}); - XLA_COUNTER("UncachedSyncTensors", 1); + ComputationCache::TypePtr cached_computation = GetComputationCache()->Add( + coll.hash, + std::make_shared(std::move(computations.front()), + parameters_data.size())); + + ScheduleSyncTensorsGraph(tensors, &coll, std::move(parameters_data), + unique_device->ToString(), + std::move(cached_computation)); } } -std::vector XLATensor::GetApplyOrder( +XLATensor::SyncTensorCollection XLATensor::CollectApplyGraphTensors( const std::vector& tensors) { SyncTensorCollection coll = CollectSyncTensors(tensors); + // The ApplyPendingGraph() only requires a barrier, as it never operates + // asynchronously, so we can release the device locks here. + coll.unlocker.clear(); // Order the tensors based on their device and unique ID, so that we try to // mazimize the chances of creating the same XLA computation, and hence // hitting the compilation cache. 
@@ -730,7 +901,7 @@ std::vector XLATensor::GetApplyOrder( } return tensors[i1].GetUniqueId() < tensors[i2].GetUniqueId(); }); - return std::move(coll.indices); + return coll; } bool XLATensor::RunCachedApply(std::vector* tensors, @@ -742,15 +913,15 @@ bool XLATensor::RunCachedApply(std::vector* tensors, for (size_t i = 0; i < tensors->size(); ++i) { uid_index_map[(*tensors)[i].GetUniqueId()] = i; } - std::vector> parameters; + std::vector> parameters; parameters.reserve(apply_context.devices.size()); for (auto& device_input_mapping : apply_context.input_mapping) { - std::vector device_parameters; + std::vector device_parameters; device_parameters.reserve(device_input_mapping.size()); for (auto uid : device_input_mapping) { auto it = uid_index_map.find(uid); if (it != uid_index_map.end()) { - const std::shared_ptr& xla_data = + const xla::ComputationClient::DataPtr& xla_data = (*tensors)[it->second].data()->xla_data; if (xla_data == nullptr) { // If we do not find real device data (we have a cached graph @@ -759,7 +930,7 @@ bool XLATensor::RunCachedApply(std::vector* tensors, XLA_COUNTER("NoTensorDataForUid", 1); return false; } - device_parameters.push_back(xla_data.get()); + device_parameters.push_back(xla_data); } else { // If we have not found the unique ID of the parameter which is // supposed to feed data to the computation, the pending graph context @@ -804,8 +975,7 @@ XLATensor::DataUidMap XLATensor::CreateDataUidMap( const std::vector& tensors) { DataUidMap data_uid_map(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { - std::shared_ptr xla_data = - tensors[i].data()->xla_data; + xla::ComputationClient::DataPtr xla_data = tensors[i].data()->xla_data; if (xla_data != nullptr) { auto it_inserted = data_uid_map.emplace(xla_data->unique_id(), tensors[i].GetUniqueId()); @@ -848,10 +1018,10 @@ void XLATensor::ApplyPendingGraph(std::vector* tensors, std::vector index_mapping; }; - std::vector order = GetApplyOrder(*tensors); + SyncTensorCollection coll = CollectApplyGraphTensors(*tensors); std::vector uid_order; - uid_order.reserve(order.size()); - for (auto i : order) { + uid_order.reserve(coll.indices.size()); + for (auto i : coll.indices) { uid_order.push_back((*tensors)[i].GetUniqueId()); } DataUidMap data_uid_map; @@ -867,13 +1037,13 @@ void XLATensor::ApplyPendingGraph(std::vector* tensors, } std::map contexts_map; - for (auto i : order) { + for (auto i : coll.indices) { DeviceContext* device_context = &contexts_map[(*tensors)[i].GetDevice()]; device_context->index_mapping.push_back(i); } std::atomic unknown_params(0); - std::vector> parameters( + std::vector> parameters( contexts_map.size()); std::vector> input_mapping(contexts_map.size()); std::vector> index_mapping(contexts_map.size()); @@ -908,7 +1078,7 @@ void XLATensor::ApplyPendingGraph(std::vector* tensors, GetCompilationDevices(devices[index]), &shapes[index]}; - std::vector parameters_data = + std::vector parameters_data = device_context->lowering_ctx.GetParametersData(); if (apply_context != nullptr) { std::vector device_input_mapping; diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index d56597d68792..a94fffea48d7 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -6,6 +6,7 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_client/cache.h" #include "tensorflow/compiler/xla/xla_client/computation_client.h" #include "tensorflow/compiler/xla/xla_client/util.h" #include 
"torch/csrc/autograd/variable.h" @@ -35,8 +36,7 @@ class XLATensor { static XLATensor Create(const at::Tensor& tensor, const Device& device, bool requires_grad); static XLATensor Create( - std::shared_ptr xla_data, - bool requires_grad, + xla::ComputationClient::DataPtr xla_data, bool requires_grad, c10::optional logical_element_type = c10::nullopt); static XLATensor Create( @@ -73,13 +73,13 @@ class XLATensor { // Fetches the XLA data behind the tensor. If the tensor has a graph defining // its current value, executes the graph and fetches the XLA data result. - std::shared_ptr GetXlaData(); + xla::ComputationClient::DataPtr GetXlaData(); // Fetches the current value of the XLA data, which can be missing (nullptr) // in case the tensor has a graph defining its current value, - std::shared_ptr CurrentXlaData() const; + xla::ComputationClient::DataPtr CurrentXlaData() const; - void SetXlaData(std::shared_ptr xla_data); + void SetXlaData(xla::ComputationClient::DataPtr xla_data); // Retrieves the current IR Node, or nullptr in case no active IR Node is // available. @@ -800,14 +800,26 @@ class XLATensor { struct SyncTensorCollection { std::vector indices; size_t hash = 0; + std::vector unlocker; }; + struct CachedComputation { + CachedComputation( + std::shared_ptr computation, + size_t num_parameters) + : computation(std::move(computation)), num_parameters(num_parameters) {} + + std::shared_ptr computation; + size_t num_parameters; + }; + + using ComputationCache = xla::util::Cache; + // This is the core XLA tensor data structure where all the tensor data is // held. The XLA tensor is nothing more than a shared pointer to a Data // object. struct Data { - Data(std::shared_ptr xla_data, - const Device& device, + Data(xla::ComputationClient::DataPtr xla_data, const Device& device, c10::optional logical_element_type) : xla_data(std::move(xla_data)), logical_element_type(logical_element_type), @@ -833,7 +845,7 @@ class XLATensor { ~Data(); - std::shared_ptr xla_data; + xla::ComputationClient::DataPtr xla_data; ir::Value ir_value; std::shared_ptr view; c10::optional logical_element_type; @@ -845,8 +857,7 @@ class XLATensor { }; XLATensor(const at::Tensor& tensor, const Device& device, bool requires_grad); - XLATensor(std::shared_ptr xla_data, - bool requires_grad, + XLATensor(xla::ComputationClient::DataPtr xla_data, bool requires_grad, c10::optional logical_element_type = c10::nullopt); XLATensor(ir::Value ir_value, const Device& device, c10::optional logical_element_type = c10::nullopt); @@ -921,23 +932,34 @@ class XLATensor { static bool RunCachedApply(std::vector* tensors, const ApplyContext& apply_context); + static ComputationCache* GetComputationCache(); + static SyncTensorCollection CollectSyncTensors( const std::vector& tensors); + // Schedules the execution of a sync tensors operation in background. The + // asynchronous operation will hold the device locks by capturing the ones + // present within the coll structure. + static void ScheduleSyncTensorsGraph( + std::vector* tensors, SyncTensorCollection* coll, + std::vector parameters_data, + std::string device, ComputationCache::TypePtr cached_computation); + static bool TryRunCachedSync(std::vector* tensors, - const SyncTensorCollection& coll); + SyncTensorCollection* coll); // Returns a permutation which represents an ordering by tensor device and // unique ID, of all the tensors which needs sync (the ones which have a graph // backing their value). The tensors which are already sync, will not be // returned within the permutation. 
If a tensor has at::Tensor data only, the // at::Tensor data will be uploaded to the device and the tensor will receive - // new XLA data. - static std::vector GetApplyOrder( + // new XLA data. This API will perform a barrier on device locks, and will not + // hold locks of participating devices (as ApplyPendingGraph() is never + // asynchronous). + static SyncTensorCollection CollectApplyGraphTensors( const std::vector& tensors); - static ir::Value CreateTensorNode( - std::shared_ptr data); + static ir::Value CreateTensorNode(xla::ComputationClient::DataPtr data); static xla::int64 GetNextTensorId();
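
Editor's note: the sketch below distills the asynchronous sync pattern that the new ScheduleSyncTensorsGraph comments describe: install a placeholder for every tensor that has no device data yet, hand the device lock to the background closure, execute, swap the real results into the placeholders, and only then release the device. It is a minimal, self-contained illustration using the standard library only; the names DataPlaceholder, DeviceBarrier and RunAsyncSync are hypothetical and do not exist in the patch.

    #include <condition_variable>
    #include <iostream>
    #include <memory>
    #include <mutex>
    #include <string>
    #include <thread>
    #include <utility>
    #include <vector>

    // Hypothetical stand-in for a device data handle: its payload is filled in
    // (swapped) once the background execution finishes.
    struct DataPlaceholder {
      std::string payload;  // empty until the async computation swaps it in
      void Swap(DataPlaceholder& other) { std::swap(payload, other.payload); }
    };

    // Hypothetical stand-in for the device lock moved into the async closure:
    // readers of the device data wait until the closure releases it.
    class DeviceBarrier {
     public:
      void Release() {
        {
          std::lock_guard<std::mutex> guard(mutex_);
          released_ = true;
        }
        cv_.notify_all();
      }
      void Wait() {
        std::unique_lock<std::mutex> guard(mutex_);
        cv_.wait(guard, [this] { return released_; });
      }

     private:
      std::mutex mutex_;
      std::condition_variable cv_;
      bool released_ = false;
    };

    // Returns placeholders immediately and fills them on a background thread,
    // mirroring the "install placeholder, execute, swap result, unlock" flow.
    std::vector<std::shared_ptr<DataPlaceholder>> RunAsyncSync(
        std::shared_ptr<DeviceBarrier> barrier, size_t num_outputs) {
      std::vector<std::shared_ptr<DataPlaceholder>> placeholders;
      for (size_t i = 0; i < num_outputs; ++i) {
        placeholders.push_back(std::make_shared<DataPlaceholder>());
      }
      std::thread([placeholders, barrier]() {
        for (size_t i = 0; i < placeholders.size(); ++i) {
          // Pretend this is the result of executing the cached computation.
          DataPlaceholder result{"result-" + std::to_string(i)};
          placeholders[i]->Swap(result);
        }
        barrier->Release();  // device becomes available to waiters
      }).detach();
      return placeholders;
    }

    int main() {
      auto barrier = std::make_shared<DeviceBarrier>();
      auto outputs = RunAsyncSync(barrier, 3);
      barrier->Wait();  // anyone needing the device data blocks here first
      for (const auto& out : outputs) {
        std::cout << out->payload << "\n";
      }
      return 0;
    }

In the actual change the unlocker also carries a status, so a failed execution (the catch block in ScheduleSyncTensorsGraph) propagates the error to the next operation on that device; that detail is omitted from the sketch for brevity.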