Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/cpp/test_ir.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#include <gtest/gtest.h>

#include "ir.h"
#include "lowering_context.h"
#include "ops/ops.h"
#include "ops/scalar.h"

TEST(IrTest, TestScalar) {
TEST(IrTest, TestScalarCreate) {
torch_xla::ir::NodePtr scalar = torch_xla::ir::ops::ScalarOp(1.0, xla::F32);
ASSERT_TRUE(scalar != nullptr);
}
9 changes: 9 additions & 0 deletions third_party/xla_client/computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
namespace xla {
namespace {

ComputationClient* CreateClient() {
return ComputationClient::Create().ConsumeValueOrDie().release();
}

string GetTpuClusterConfigPath() {
string home_folder = sys_util::GetEnvString("HOME", ".");
return absl::StrCat(home_folder, "/", ".pytorch_tpu.conf");
Expand Down Expand Up @@ -161,6 +165,11 @@ int64 ComputationClient::GetDeviceOrdinal(const string& device) {
return std::stoi(device.substr(pos + 1));
}

ComputationClient* ComputationClient::Get() {
static ComputationClient* computation_client = CreateClient();
return computation_client;
}

metrics::Metric* ComputationClient::TransferToServerMetric() {
static metrics::Metric* metric =
new metrics::Metric("TransferToServerTime", metrics::MetricFnTime);
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla_client/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ class ComputationClient {
// after the last ':' character of the device string.
static int64 GetDeviceOrdinal(const string& device);

// Returns the ComputationClient singleton.
static ComputationClient* Get();

protected:
// Metrics common to all client intrfaces.
static metrics::Metric* TransferToServerMetric();
Expand Down
10 changes: 5 additions & 5 deletions torch_xla/csrc/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ void XlaModule::backward(const TensorBatchVector& grad_outputs) {
GetBackwardBuildOptions(inputs_.size()))
.computation;
xla::Shape result_shape = GetResultShape(computation, grad_outputs);
backward_computation_ = XlaGetClient()->Compile(
backward_computation_ = xla::ComputationClient::Get()->Compile(
std::move(computation), GetStringDevices(), &result_shape);
}
// Collect the computation client data vector.
Expand Down Expand Up @@ -346,7 +346,7 @@ XlaModule::TensorBatchVector XlaModule::RunFusedTrain(
xla::XlaComputation computation =
BuildFusedTrainComputation(forward_shapes);
xla::Shape result_shape = GetResultShape(computation, inputs);
forward_computation_ = XlaGetClient()->Compile(
forward_computation_ = xla::ComputationClient::Get()->Compile(
std::move(computation), GetStringDevices(), &result_shape);
}

Expand Down Expand Up @@ -499,7 +499,7 @@ XlaModule::TensorBatchVector XlaModule::RunUnfusedForward(

xla::Shape result_shape =
GetResultShape(forward_translation_result.computation, inputs);
forward_computation_ = XlaGetClient()->Compile(
forward_computation_ = xla::ComputationClient::Get()->Compile(
std::move(forward_translation_result.computation), GetStringDevices(),
&result_shape);
}
Expand Down Expand Up @@ -583,11 +583,11 @@ XlaModule::TensorBatchVector XlaModule::Execute(
exec_results;
if (inputs.size() == 1) {
xla::ComputationClient::ExecuteComputationOptions options;
exec_results.push_back(XlaGetClient()->ExecuteComputation(
exec_results.push_back(xla::ComputationClient::Get()->ExecuteComputation(
computation, inputs.front(), computation.devices()[0], options));
} else {
xla::ComputationClient::ExecuteReplicatedOptions options;
exec_results = XlaGetClient()->ExecuteReplicated(
exec_results = xla::ComputationClient::Get()->ExecuteReplicated(
computation, inputs, computation.devices(), options);
}
return CreateResultBatchVector(std::move(exec_results));
Expand Down
25 changes: 14 additions & 11 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ XLATensor::XLATensor(const torch::autograd::Variable& tensor,
tensor.sizes(),
XlaHelpers::MakeXlaPrimitiveType(tensor.type().scalarType()),
device.hw_type),
device, XlaGetClient()),
device, xla::ComputationClient::Get()),
device)),
requires_grad_(tensor.requires_grad()) {}

Expand Down Expand Up @@ -447,7 +447,7 @@ const at::Tensor& XLATensor::ToTensor() {
ApplyPendingGraph();

std::vector<xla::Literal> literals =
XlaGetClient()->TransferFromServer({GetXlaData()});
xla::ComputationClient::Get()->TransferFromServer({GetXlaData()});
tensor_data = std::make_shared<at::Tensor>(torch::autograd::make_variable(
MakeTensorFromXlaLiteral(literals.front()), RequiresGrad()));
SetTensorData(tensor_data);
Expand All @@ -470,7 +470,7 @@ std::vector<at::Tensor> XLATensor::GetTensors(
tensors_data.push_back(tensor->GetXlaData());
}
std::vector<xla::Literal> literals =
XlaGetClient()->TransferFromServer(tensors_data);
xla::ComputationClient::Get()->TransferFromServer(tensors_data);
std::vector<at::Tensor> results;
for (size_t i = 0; i < literals.size(); ++i) {
results.push_back(torch::autograd::make_variable(
Expand All @@ -495,7 +495,8 @@ std::vector<std::shared_ptr<XLATensor>> XLATensor::CreateTensors(
};
literal_device.emplace_back(std::move(converter), devices[i]);
}
auto handles = XlaGetClient()->TransferToServer(literal_device);
auto handles =
xla::ComputationClient::Get()->TransferToServer(literal_device);
std::vector<std::shared_ptr<XLATensor>> xla_tensors;
for (size_t i = 0; i < handles.size(); ++i) {
xla_tensors.push_back(
Expand Down Expand Up @@ -607,12 +608,12 @@ void XLATensor::ApplyPendingGraph() {
xla::XlaOp root = lowering_ctx.GetOutputOp(ir::Output(ir_node.get(), 0));
xla::XlaComputation computation =
lowering_ctx.Build(root).ConsumeValueOrDie();
auto compiled_computation = XlaGetClient()->Compile(
auto compiled_computation = xla::ComputationClient::Get()->Compile(
std::move(computation), {GetDevice().ToString()},
/*output_shape=*/nullptr);
xla::ComputationClient::ExecuteComputationOptions options;
options.explode_tuple = false;
auto results = XlaGetClient()->ExecuteComputation(
auto results = xla::ComputationClient::Get()->ExecuteComputation(
*compiled_computation, lowering_ctx.GetParametersData(),
compiled_computation->devices()[0], options);
XLA_CHECK_EQ(results.size(), 1);
Expand Down Expand Up @@ -695,7 +696,7 @@ bool XLATensor::RunCachedApply(
}

xla::ComputationClient::ExecuteParallelOptions options;
auto results = XlaGetClient()->ExecuteParallel(
auto results = xla::ComputationClient::Get()->ExecuteParallel(
xla::util::GetConstSharedPointers(apply_context.computations), parameters,
apply_context.devices, options);
size_t device_index = 0;
Expand Down Expand Up @@ -828,10 +829,10 @@ void XLATensor::ApplyPendingGraph(
std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
computations;
if (!instances.empty()) {
computations = XlaGetClient()->Compile(std::move(instances));
computations = xla::ComputationClient::Get()->Compile(std::move(instances));

xla::ComputationClient::ExecuteParallelOptions options;
auto results = XlaGetClient()->ExecuteParallel(
auto results = xla::ComputationClient::Get()->ExecuteParallel(
xla::util::GetConstSharedPointers(computations), parameters, devices,
options);
auto context_iterator = contexts_map.begin();
Expand All @@ -856,12 +857,14 @@ void XLATensor::ApplyPendingGraph(

XLATensor::Device XLATensor::DeviceFromString(const std::string& device_spec) {
if (device_spec.empty()) {
const std::string default_device_spec = XlaGetClient()->GetDefaultDevice();
const std::string default_device_spec =
xla::ComputationClient::Get()->GetDefaultDevice();
XLA_CHECK(!default_device_spec.empty());
return DeviceFromString(default_device_spec);
}
if (device_spec[0] == ':') {
const std::string default_device_spec = XlaGetClient()->GetDefaultDevice();
const std::string default_device_spec =
xla::ComputationClient::Get()->GetDefaultDevice();
auto pos = default_device_spec.find(':');
XLA_CHECK_NE(pos, std::string::npos) << default_device_spec;
return DeviceFromString(default_device_spec.substr(0, pos) + device_spec);
Expand Down
9 changes: 0 additions & 9 deletions torch_xla/csrc/translator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
namespace torch_xla {
namespace {

xla::ComputationClient* CreateClient() {
return xla::ComputationClient::Create().ConsumeValueOrDie().release();
}

xla::XlaOp GetConstantOp(xla::XlaBuilder* builder,
const torch::jit::Node* node) {
auto value = toIValue(node->output()).value();
Expand Down Expand Up @@ -639,11 +635,6 @@ GetTranslationHandlers() {

} // namespace

xla::ComputationClient* XlaGetClient() {
static xla::ComputationClient* computation_client = CreateClient();
return computation_client;
}

XlaTranslator::XlaTranslator(
const std::shared_ptr<torch::jit::Graph>& graph,
const xla::PrecisionConfig::Precision conv_precision)
Expand Down
3 changes: 0 additions & 3 deletions torch_xla/csrc/translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include <string>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "torch/csrc/jit/ir.h"

namespace torch_xla {
Expand Down Expand Up @@ -75,6 +74,4 @@ class XlaTranslator {
xla::PrecisionConfig::Precision conv_precision_;
};

xla::ComputationClient* XlaGetClient();

} // namespace torch_xla