Skip to content

Commit

Permalink
[XLA:Python] Add a prototype XRT client that runs XRT ops using the TensorFlow remote eager protocol.

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 241212924
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Mar 31, 2019
1 parent 4ac9871 commit 937c81c
Show file tree
Hide file tree
Showing 8 changed files with 430 additions and 6 deletions.
46 changes: 45 additions & 1 deletion tensorflow/compiler/xla/python/BUILD
Expand Up @@ -8,7 +8,10 @@ load("//tensorflow:tensorflow.bzl", "tf_pybind_extension")

py_library(
name = "xla_client",
srcs = ["xla_client.py"],
srcs = [
"xla_client.py",
"xrt.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [":xla_extension"],
Expand Down Expand Up @@ -60,6 +63,33 @@ cc_library(
],
)

# C++ pybind11 bindings for the prototype XRT client. Linked into
# :xla_extension, which calls tensorflow::AddXrtSubmodule() to register the
# "xrt" Python submodule.
cc_library(
    name = "xrt",
    srcs = ["xrt.cc"],
    hdrs = ["xrt.h"],
    # pybind11 requires exceptions; -fno-strict-aliasing matches the other
    # pybind11 targets in this package.
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    # Header modules are incompatible with -fexceptions builds of pybind11.
    features = ["-use_header_modules"],
    deps = [
        ":types",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xrt/client:xrt_client",
        "//tensorflow/compiler/xrt/client:xrt_grpc_eager_client",
        "//tensorflow/core:lib",
        "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:optional",
        "@pybind11",
    ],
)

tf_pybind_extension(
name = "xla_extension",
srcs = [
Expand All @@ -76,6 +106,7 @@ tf_pybind_extension(
module_name = "xla_extension",
deps = [
":types",
":xrt",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
Expand Down Expand Up @@ -114,3 +145,16 @@ tf_pybind_extension(
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
] + xla_python_default_plugins(),
)

# TODO(phawkins): enable this test.
# py_test(
# name = "xrt_test",
# srcs = ["xrt_test.py"],
# deps = [
# ":xla_client",
# "//third_party/py/numpy",
# "//tensorflow/compiler/jit:xla_cpu_device",
# "//tensorflow/compiler/xrt:xrt_server",
# "//tensorflow/python:client_testlib",
# ],
# )
3 changes: 3 additions & 0 deletions tensorflow/compiler/xla/python/xla.cc
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/python/xrt.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape.h"
Expand Down Expand Up @@ -412,6 +413,8 @@ PYBIND11_MODULE(xla_extension, m) {
// TODO(phawkins): improve bindings for these types.
py::class_<ChannelHandle>(m, "ChannelHandle");
py::class_<PrecisionConfig>(m, "PrecisionConfig");

tensorflow::AddXrtSubmodule(&m);
}

} // namespace xla_python
Expand Down
178 changes: 178 additions & 0 deletions tensorflow/compiler/xla/python/xrt.cc
@@ -0,0 +1,178 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>

#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xrt/client/xrt_client.h"
#include "tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

namespace py = pybind11;

// Connects to a single remote TensorFlow worker and returns an XrtTfClient
// that talks to it over the gRPC eager protocol.
//
// `address` is the worker's host:port; `worker` is the job name to give the
// worker in the synthesized single-task ClusterDef. Returns an error status
// if the gRPC channel cache cannot be built.
xla::StatusOr<std::shared_ptr<XrtTfClient>> GetTfClient(const string& address,
                                                        const string& worker) {
  // Synthesize a minimal cluster: one job named `worker`, one task (index 0)
  // at `address`.
  ClusterDef cluster;
  JobDef* job_def = cluster.add_job();
  job_def->set_name(worker);
  (*job_def->mutable_tasks())[0] = address;

  // Build the gRPC channel cache for that cluster.
  ChannelCreationFunction make_channel =
      ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
  TF_ASSIGN_OR_RETURN(std::shared_ptr<GrpcChannelCache> cache,
                      GetGrpcChannelCache(cluster, make_channel));
  return std::make_shared<XrtTfClient>(cluster, cache);
}

// TODO(phawkins): This function won't produce a particularly good device
// assignment since it knows nothing about the hardware or its topology.
// It's here mostly as a placeholder until we do something smarter.
xla::StatusOr<xla::DeviceAssignment> AssignDevices(int num_replicas,
                                                   int num_computations) {
  // Delegate to the generic placer, which hands out device ordinals without
  // any topology awareness.
  xla::ComputationPlacer placer;
  return placer.AssignDevices(num_replicas, num_computations);
}

} // namespace

// Registers an "xrt" submodule on `module`, exposing the XRT client classes
// (XrtTfClient, XrtTfContext, XrtContext, XrtBuffer, XrtExecutable) and
// helper functions to Python.
void AddXrtSubmodule(py::module* module) {
  py::module m = module->def_submodule("xrt", "XRT backend");

  m.def("AssignDevices", &AssignDevices,
        "Computes a default device assignment.");

  py::class_<XrtTfClient, std::shared_ptr<XrtTfClient>> xrt_tf_client(
      m, "XrtTfClient");
  m.def("GetTfClient", &GetTfClient, "Returns a TensorFlow client.");

  // NOTE(review): "async" is a reserved keyword in Python 3.7+, so this
  // attribute cannot be accessed with ordinary attribute syntax there;
  // callers would need getattr/setattr. Consider renaming the Python-side
  // attribute.
  py::class_<XrtTfContext::Options>(m, "XrtTfContextOptions")
      .def(py::init<>())
      .def_readwrite("async", &XrtTfContext::Options::async)
      .def_readwrite("max_queue_size", &XrtTfContext::Options::max_queue_size);

  py::class_<XrtTfContext, std::shared_ptr<XrtTfContext>>(m, "XrtTfContext")
      .def_static("Create", &XrtTfContext::Create);

  py::class_<XrtContext, std::shared_ptr<XrtContext>>(m, "XrtContext")
      .def_static("Create", &XrtContext::Create)
      .def("DeviceCount", &XrtContext::device_count)
      .def_property_readonly("tf_device_ids", &XrtContext::tf_device_ids);

  py::class_<XrtBuffer, std::shared_ptr<XrtBuffer>>(m, "XrtBuffer")
      .def_static("FromLiteral", &XrtBuffer::FromLiteral)
      // Copies the buffer's contents back to the host as a Python object.
      // The GIL is released around the device-to-host transfer so other
      // Python threads can make progress.
      .def("ToPython",
           [](std::shared_ptr<XrtBuffer> buffer) -> xla::StatusOr<py::object> {
             auto literal = absl::make_unique<xla::Literal>();
             {
               py::gil_scoped_release gil_release;
               TF_ASSIGN_OR_RETURN(*literal, buffer->ToLiteral());
             }
             return xla::LiteralToPython(std::move(literal));
           })
      .def("Delete", &XrtBuffer::Delete)
      .def("DestructureTuple", &XrtBuffer::DestructureTuple);

  py::class_<XrtExecutable, std::shared_ptr<XrtExecutable>>(m, "XrtExecutable")
      // Compiles a serialized HloModuleProto for the given argument/result
      // shapes and device assignment.
      // NOTE(review): the return value of ParsePartialFromString is ignored,
      // so a corrupt serialized proto is silently accepted — consider
      // checking it and returning InvalidArgument.
      .def_static("Compile",
                  [](std::shared_ptr<XrtContext> context,
                     const std::string& hlo_module_proto_serialized,
                     const std::vector<xla::Shape>& argument_shapes,
                     const xla::Shape& result_shape,
                     const xla::DeviceAssignment& device_assignment) {
                    xla::HloModuleProto hlo_module_proto;
                    hlo_module_proto.ParsePartialFromString(
                        hlo_module_proto_serialized);
                    return XrtExecutable::Compile(context, hlo_module_proto,
                                                  argument_shapes, result_shape,
                                                  device_assignment);
                  })
      .def("Execute", &XrtExecutable::Execute)
      // Runs the executable across replicas/computations. `pyargs` is indexed
      // [computation][replica][argument] and must match the executable's
      // device assignment; the result is indexed [computation][replica].
      .def("ExecuteReplicated",
           [](XrtExecutable& executable,
              std::vector<std::vector<std::vector<std::shared_ptr<XrtBuffer>>>>
                  pyargs)
               -> xla::StatusOr<
                   std::vector<std::vector<std::shared_ptr<XrtBuffer>>>> {
             const xla::DeviceAssignment& device_assignment =
                 executable.device_assignment();
             // Validate the outermost (per-computation) dimension.
             if (pyargs.size() != device_assignment.computation_count()) {
               return xla::InvalidArgument(
                   "Outermost argument list must have one entry per "
                   "computation; "
                   "got %d args, device assignment has %d computations.",
                   pyargs.size(), device_assignment.computation_count());
             }
             // Repack the ragged Python lists into dense Array2Ds of
             // (replica, argument) per computation, validating sizes as we go.
             std::vector<xla::Array2D<std::shared_ptr<XrtBuffer>>> args(
                 pyargs.size());
             for (int i = 0; i < pyargs.size(); ++i) {
               // The .empty() check guards the pyargs[i][0] access below.
               if (pyargs[i].size() != device_assignment.replica_count() ||
                   pyargs[i].empty()) {
                 return xla::InvalidArgument(
                     "Mismatch in number of replicas; got %d arguments, but "
                     "device assignment has %d replicas.",
                     pyargs[i].size(), device_assignment.replica_count());
               }

               // Every replica of a computation must receive the same number
               // of arguments.
               int arg_count = pyargs[i][0].size();
               args[i] = xla::Array2D<std::shared_ptr<XrtBuffer>>(
                   device_assignment.replica_count(), arg_count);
               for (int j = 0; j < pyargs[i].size(); ++j) {
                 if (pyargs[i][j].size() != arg_count) {
                   return xla::InvalidArgument(
                       "Mismatched number of arguments to computation %d for "
                       "different replicas; %d vs %d arguments.",
                       i, arg_count, pyargs[i][j].size());
                 }
                 for (int k = 0; k < arg_count; ++k) {
                   args[i](j, k) = pyargs[i][j][k];
                 }
               }
             }

             // Unpack the Array2D result back into nested Python-friendly
             // vectors, indexed [computation][replica].
             TF_ASSIGN_OR_RETURN(auto result,
                                 executable.ExecuteReplicated(args));
             std::vector<std::vector<std::shared_ptr<XrtBuffer>>> pyresult(
                 result.n1());
             for (int i = 0; i < result.n1(); ++i) {
               pyresult[i].resize(result.n2());
               for (int j = 0; j < result.n2(); ++j) {
                 pyresult[i][j] = result(i, j);
               }
             }
             return pyresult;
           })
      .def("Delete", &XrtExecutable::Delete)
      // Flattened list of the device ordinals this executable runs on.
      .def("DeviceOrdinals", [](const XrtExecutable& executable) {
        return std::vector<int>(executable.device_assignment().begin(),
                                executable.device_assignment().end());
      });

  m.doc() = "XRT backend plugin";
}

} // namespace tensorflow
27 changes: 27 additions & 0 deletions tensorflow/compiler/xla/python/xrt.h
@@ -0,0 +1,27 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_

#include "include/pybind11/pybind11.h"

namespace tensorflow {

// Creates an "xrt" submodule on `module` and populates it with the pybind11
// bindings for the XRT client (implemented in xrt.cc).
void AddXrtSubmodule(pybind11::module* module);

}  // namespace tensorflow

#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_

0 comments on commit 937c81c

Please sign in to comment.