From d6442054c16bdb6291813506b87045fa80b79945 Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Wed, 16 Jan 2019 14:23:34 -0800
Subject: [PATCH] Move the client-get into the computation client.

---
 test/cpp/test_ir.cpp                         |  3 ++-
 third_party/xla_client/computation_client.cc |  9 +++++++
 third_party/xla_client/computation_client.h  |  3 +++
 torch_xla/csrc/module.cpp                    | 10 ++++----
 torch_xla/csrc/tensor.cpp                    | 25 +++++++++++---------
 torch_xla/csrc/translator.cpp                |  9 -------
 torch_xla/csrc/translator.h                  |  3 ---
 7 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/test/cpp/test_ir.cpp b/test/cpp/test_ir.cpp
index 99286692fed9..3b404870c5f7 100644
--- a/test/cpp/test_ir.cpp
+++ b/test/cpp/test_ir.cpp
@@ -1,10 +1,11 @@
 #include <gtest/gtest.h>
 
 #include "ir.h"
+#include "lowering_context.h"
 #include "ops/ops.h"
 #include "ops/scalar.h"
 
-TEST(IrTest, TestScalar) {
+TEST(IrTest, TestScalarCreate) {
   torch_xla::ir::NodePtr scalar = torch_xla::ir::ops::ScalarOp(1.0, xla::F32);
   ASSERT_TRUE(scalar != nullptr);
 }
diff --git a/third_party/xla_client/computation_client.cc b/third_party/xla_client/computation_client.cc
index d4100cc359ca..2481d7c58dd7 100644
--- a/third_party/xla_client/computation_client.cc
+++ b/third_party/xla_client/computation_client.cc
@@ -17,6 +17,10 @@ namespace xla {
 namespace {
 
+ComputationClient* CreateClient() {
+  return ComputationClient::Create().ConsumeValueOrDie().release();
+}
+
 string GetTpuClusterConfigPath() {
   string home_folder = sys_util::GetEnvString("HOME", ".");
   return absl::StrCat(home_folder, "/", ".pytorch_tpu.conf");
@@ -161,6 +165,11 @@ int64 ComputationClient::GetDeviceOrdinal(const string& device) {
   return std::stoi(device.substr(pos + 1));
 }
 
+ComputationClient* ComputationClient::Get() {
+  static ComputationClient* computation_client = CreateClient();
+  return computation_client;
+}
+
 metrics::Metric* ComputationClient::TransferToServerMetric() {
   static metrics::Metric* metric =
       new metrics::Metric("TransferToServerTime", metrics::MetricFnTime);
diff --git a/third_party/xla_client/computation_client.h b/third_party/xla_client/computation_client.h
index 5435f396961b..b56b512b0b2b 100644
--- a/third_party/xla_client/computation_client.h
+++ b/third_party/xla_client/computation_client.h
@@ -169,6 +169,9 @@ class ComputationClient {
   // after the last ':' character of the device string.
   static int64 GetDeviceOrdinal(const string& device);
 
+  // Returns the ComputationClient singleton.
+  static ComputationClient* Get();
+
  protected:
   // Metrics common to all client intrfaces.
   static metrics::Metric* TransferToServerMetric();
diff --git a/torch_xla/csrc/module.cpp b/torch_xla/csrc/module.cpp
index 6a1bbfc43b19..6d12611ebde6 100644
--- a/torch_xla/csrc/module.cpp
+++ b/torch_xla/csrc/module.cpp
@@ -279,7 +279,7 @@ void XlaModule::backward(const TensorBatchVector& grad_outputs) {
                               GetBackwardBuildOptions(inputs_.size()))
             .computation;
     xla::Shape result_shape = GetResultShape(computation, grad_outputs);
-    backward_computation_ = XlaGetClient()->Compile(
+    backward_computation_ = xla::ComputationClient::Get()->Compile(
         std::move(computation), GetStringDevices(), &result_shape);
   }
   // Collect the computation client data vector.
@@ -346,7 +346,7 @@ XlaModule::TensorBatchVector XlaModule::RunFusedTrain(
     xla::XlaComputation computation =
         BuildFusedTrainComputation(forward_shapes);
     xla::Shape result_shape = GetResultShape(computation, inputs);
-    forward_computation_ = XlaGetClient()->Compile(
+    forward_computation_ = xla::ComputationClient::Get()->Compile(
         std::move(computation), GetStringDevices(), &result_shape);
   }
 
@@ -499,7 +499,7 @@ XlaModule::TensorBatchVector XlaModule::RunUnfusedForward(
 
     xla::Shape result_shape =
         GetResultShape(forward_translation_result.computation, inputs);
-    forward_computation_ = XlaGetClient()->Compile(
+    forward_computation_ = xla::ComputationClient::Get()->Compile(
         std::move(forward_translation_result.computation), GetStringDevices(),
         &result_shape);
   }
@@ -583,11 +583,11 @@ XlaModule::TensorBatchVector XlaModule::Execute(
       exec_results;
   if (inputs.size() == 1) {
     xla::ComputationClient::ExecuteComputationOptions options;
-    exec_results.push_back(XlaGetClient()->ExecuteComputation(
+    exec_results.push_back(xla::ComputationClient::Get()->ExecuteComputation(
         computation, inputs.front(), computation.devices()[0], options));
   } else {
     xla::ComputationClient::ExecuteReplicatedOptions options;
-    exec_results = XlaGetClient()->ExecuteReplicated(
+    exec_results = xla::ComputationClient::Get()->ExecuteReplicated(
         computation, inputs, computation.devices(), options);
   }
   return CreateResultBatchVector(std::move(exec_results));
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index f486b08790b6..fa75724d005c 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -315,7 +315,7 @@ XLATensor::XLATensor(const torch::autograd::Variable& tensor,
                   tensor.sizes(),
                   XlaHelpers::MakeXlaPrimitiveType(tensor.type().scalarType()),
                   device.hw_type),
-              device, XlaGetClient()),
+              device, xla::ComputationClient::Get()),
               device)),
       requires_grad_(tensor.requires_grad()) {}
 
@@ -447,7 +447,7 @@ const at::Tensor& XLATensor::ToTensor() {
     ApplyPendingGraph();
 
     std::vector<xla::Literal> literals =
-        XlaGetClient()->TransferFromServer({GetXlaData()});
+        xla::ComputationClient::Get()->TransferFromServer({GetXlaData()});
     tensor_data = std::make_shared<at::Tensor>(torch::autograd::make_variable(
         MakeTensorFromXlaLiteral(literals.front()), RequiresGrad()));
     SetTensorData(tensor_data);
@@ -470,7 +470,7 @@ std::vector<at::Tensor> XLATensor::GetTensors(
     tensors_data.push_back(tensor->GetXlaData());
   }
   std::vector<xla::Literal> literals =
-      XlaGetClient()->TransferFromServer(tensors_data);
+      xla::ComputationClient::Get()->TransferFromServer(tensors_data);
   std::vector<at::Tensor> results;
   for (size_t i = 0; i < literals.size(); ++i) {
     results.push_back(torch::autograd::make_variable(
@@ -495,7 +495,8 @@ std::vector<std::shared_ptr<XLATensor>> XLATensor::CreateTensors(
     };
     literal_device.emplace_back(std::move(converter), devices[i]);
   }
-  auto handles = XlaGetClient()->TransferToServer(literal_device);
+  auto handles =
+      xla::ComputationClient::Get()->TransferToServer(literal_device);
   std::vector<std::shared_ptr<XLATensor>> xla_tensors;
   for (size_t i = 0; i < handles.size(); ++i) {
     xla_tensors.push_back(
@@ -607,12 +608,12 @@ void XLATensor::ApplyPendingGraph() {
     xla::XlaOp root = lowering_ctx.GetOutputOp(ir::Output(ir_node.get(), 0));
     xla::XlaComputation computation =
         lowering_ctx.Build(root).ConsumeValueOrDie();
-    auto compiled_computation = XlaGetClient()->Compile(
+    auto compiled_computation = xla::ComputationClient::Get()->Compile(
         std::move(computation), {GetDevice().ToString()},
         /*output_shape=*/nullptr);
     xla::ComputationClient::ExecuteComputationOptions options;
     options.explode_tuple = false;
-    auto results = XlaGetClient()->ExecuteComputation(
+    auto results = xla::ComputationClient::Get()->ExecuteComputation(
         *compiled_computation, lowering_ctx.GetParametersData(),
         compiled_computation->devices()[0], options);
     XLA_CHECK_EQ(results.size(), 1);
@@ -695,7 +696,7 @@ bool XLATensor::RunCachedApply(
   }
 
   xla::ComputationClient::ExecuteParallelOptions options;
-  auto results = XlaGetClient()->ExecuteParallel(
+  auto results = xla::ComputationClient::Get()->ExecuteParallel(
       xla::util::GetConstSharedPointers(apply_context.computations),
       parameters, apply_context.devices, options);
   size_t device_index = 0;
@@ -828,10 +829,10 @@ void XLATensor::ApplyPendingGraph(
   std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
      computations;
   if (!instances.empty()) {
-    computations = XlaGetClient()->Compile(std::move(instances));
+    computations = xla::ComputationClient::Get()->Compile(std::move(instances));
 
     xla::ComputationClient::ExecuteParallelOptions options;
-    auto results = XlaGetClient()->ExecuteParallel(
+    auto results = xla::ComputationClient::Get()->ExecuteParallel(
         xla::util::GetConstSharedPointers(computations), parameters, devices,
         options);
     auto context_iterator = contexts_map.begin();
@@ -856,12 +857,14 @@ void XLATensor::ApplyPendingGraph(
 
 XLATensor::Device XLATensor::DeviceFromString(const std::string& device_spec) {
   if (device_spec.empty()) {
-    const std::string default_device_spec = XlaGetClient()->GetDefaultDevice();
+    const std::string default_device_spec =
+        xla::ComputationClient::Get()->GetDefaultDevice();
     XLA_CHECK(!default_device_spec.empty());
     return DeviceFromString(default_device_spec);
   }
   if (device_spec[0] == ':') {
-    const std::string default_device_spec = XlaGetClient()->GetDefaultDevice();
+    const std::string default_device_spec =
+        xla::ComputationClient::Get()->GetDefaultDevice();
     auto pos = default_device_spec.find(':');
     XLA_CHECK_NE(pos, std::string::npos) << default_device_spec;
     return DeviceFromString(default_device_spec.substr(0, pos) + device_spec);
diff --git a/torch_xla/csrc/translator.cpp b/torch_xla/csrc/translator.cpp
index b0e8e5611d2d..cac3a9d738eb 100644
--- a/torch_xla/csrc/translator.cpp
+++ b/torch_xla/csrc/translator.cpp
@@ -22,10 +22,6 @@ namespace torch_xla {
 namespace {
 
-xla::ComputationClient* CreateClient() {
-  return xla::ComputationClient::Create().ConsumeValueOrDie().release();
-}
-
 xla::XlaOp GetConstantOp(xla::XlaBuilder* builder,
                          const torch::jit::Node* node) {
   auto value = toIValue(node->output()).value();
@@ -639,11 +635,6 @@ GetTranslationHandlers() {
 
 }  // namespace
 
-xla::ComputationClient* XlaGetClient() {
-  static xla::ComputationClient* computation_client = CreateClient();
-  return computation_client;
-}
-
 XlaTranslator::XlaTranslator(
     const std::shared_ptr<torch::jit::Graph>& graph,
     const xla::PrecisionConfig::Precision conv_precision)
diff --git a/torch_xla/csrc/translator.h b/torch_xla/csrc/translator.h
index 9d29f86c0e8c..bee86235b24f 100644
--- a/torch_xla/csrc/translator.h
+++ b/torch_xla/csrc/translator.h
@@ -3,7 +3,6 @@
 #include <memory>
 
 #include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/xla_client/computation_client.h"
 #include "torch/csrc/jit/ir.h"
 
 namespace torch_xla {
@@ -75,6 +74,4 @@ class XlaTranslator {
   xla::PrecisionConfig::Precision conv_precision_;
 };
 
-xla::ComputationClient* XlaGetClient();
-
 }  // namespace torch_xla
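
Reviewer note: the net effect at every call site is to replace the free
function XlaGetClient(), previously declared in translator.h, with the static
accessor xla::ComputationClient::Get(); the lazily-initialized function-local
static and its CreateClient() helper move into computation_client.cc. A
minimal usage sketch, restricted to APIs that appear in this patch (the
standalone main() wrapper is illustrative only, not part of the change):

  #include <iostream>
  #include <string>

  #include "tensorflow/compiler/xla/xla_client/computation_client.h"

  int main() {
    // The first call creates the client via ComputationClient::Create();
    // every later call returns the same instance (function-local static).
    xla::ComputationClient* client = xla::ComputationClient::Get();

    // Same pattern XLATensor::DeviceFromString() uses in tensor.cpp above.
    std::cout << client->GetDefaultDevice() << std::endl;
    return 0;
  }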