diff --git a/tensorflow/compiler/xrt/client/BUILD b/tensorflow/compiler/xrt/client/BUILD index 072073c018d4d3..f926f4a28ffcf1 100644 --- a/tensorflow/compiler/xrt/client/BUILD +++ b/tensorflow/compiler/xrt/client/BUILD @@ -22,6 +22,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_client_cq_tag", "//tensorflow/core/distributed_runtime/rpc:grpc_state", @@ -42,6 +43,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -54,14 +56,38 @@ cc_library( ], ) +cc_library( + name = "xrt_client", + srcs = ["xrt_client.cc"], + hdrs = ["xrt_client.h"], + deps = [ + ":xrt_tf_client", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xrt:xrt_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf/tpu:topology_proto_cc", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + tf_cc_test( name = "xrt_client_test", srcs = ["xrt_client_test.cc"], data = [":xrt_testlib_server"], deps = [ + ":xrt_client", ":xrt_grpc_eager_client", ":xrt_tf_client", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/xrt/client/xrt_client.cc b/tensorflow/compiler/xrt/client/xrt_client.cc new file mode 100644 index 00000000000000..e6b0ad05d4e1b4 --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_client.cc @@ -0,0 +1,603 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xrt/client/xrt_client.h" + +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xrt/client/xrt_tf_client.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/tensor_coding.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" + +namespace tensorflow { + +namespace { + +// Deserializes a TensorProto containing a scalar string value. +xla::StatusOr DeserializeTensorProtoAsString( + const TensorProto& proto) { + if (proto.dtype() != DT_STRING) { + return errors::InvalidArgument("Tensors must be of type DT_STRING, got ", + DataType_Name(proto.dtype())); + } + if (proto.tensor_shape().dim_size() != 0 || + proto.tensor_shape().unknown_rank()) { + return errors::InvalidArgument("String tensor must be a scalar, got ", + proto.tensor_shape().DebugString()); + } + if (proto.string_val_size() > 0) { + if (proto.string_val_size() != 1) { + return errors::InvalidArgument( + "Expected at most one string_val in TensorProto, got ", + proto.string_val_size()); + } + return proto.string_val(0); + } else { + std::string data; + port::DecodeStringList(proto.tensor_content(), &data, 1); + return data; + } +} + +// Deserializes a xla::Literal from a TensorProto. 
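+// XRT ops exchange literals as scalar DT_STRING tensors holding a serialized
+// xla::LiteralProto, so deserialization is two steps: unwrap the string
+// tensor, then parse the proto and convert it to an xla::Literal.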
+xla::StatusOr DeserializeTensorProtoAsLiteral( + const TensorProto& proto) { + TF_ASSIGN_OR_RETURN(std::string data, DeserializeTensorProtoAsString(proto)); + xla::LiteralProto literal_proto; + literal_proto.ParsePartialFromString(data); + return xla::Literal::CreateFromProto(literal_proto); +} + +} // namespace + +XrtBuffer::XrtBuffer(XrtTensorHandle handle, xla::Shape shape) + : handle_(std::move(handle)), shape_(std::move(shape)) {} + +XrtBuffer::~XrtBuffer() { Delete(); } + +/*static*/ xla::StatusOr> XrtBuffer::FromLiteral( + const std::shared_ptr& context, int xrt_device_ordinal, + const xla::LiteralSlice& literal) { + xrt::XLAAllocation allocation; + *allocation.mutable_value() = literal.ToProto(); + + auto proto = absl::make_unique(); + proto->set_dtype(DT_STRING); + allocation.SerializeToString(proto->add_string_val()); + + if (xrt_device_ordinal < 0 || + xrt_device_ordinal >= context->tf_device_ids().size()) { + return errors::InvalidArgument("Invalid XRT device ordinal ", + xrt_device_ordinal); + } + int tf_device_id = context->tf_device_ids().at(xrt_device_ordinal); + XrtTensorHandle literal_handle = + context->tf_context()->SendTensor(std::move(proto), tf_device_id, + /*host_memory=*/true); + + XrtTensorHandle buffer_handle = std::move(context->tf_context()->EnqueueOp( + "XRTAllocate", {&literal_handle}, /*output_arity=*/1, /*attrs=*/{}, + tf_device_id)[0]); + + return std::make_shared(std::move(buffer_handle), literal.shape()); +} + +xla::StatusOr XrtBuffer::ToLiteral() const { + TF_RET_CHECK(handle_.valid()); + XrtTensorHandle literal_handle = std::move(handle_.context()->EnqueueOp( + "XRTReadLiteral", {&handle_}, /*output_arity=*/1, /*attrs=*/{}, + handle_.device_id())[0]); + + std::shared_ptr future = + handle_.context()->RecvTensor(literal_handle, DT_STRING, + /*host_memory=*/true); + + // Flush the queue to make sure the producers are dispatched before blocking + // on the future. 
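+ // (RecvTensor does not flush the queue itself, so without this the
+ // XRTReadLiteral op enqueued above might never be dispatched and the
+ // Get() below would block indefinitely.)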
+ handle_.context()->FlushQueue(); + + TF_ASSIGN_OR_RETURN(RecvTensorResponse * response, future->Get()); + VLOG(10) << "ToLiteral received tensor " << response->DebugString(); + TF_RET_CHECK(!response->is_dead()); + return DeserializeTensorProtoAsLiteral(response->tensor()); +} + +void XrtBuffer::Delete() { + if (handle_.valid()) { + handle_.context()->EnqueueOp("XRTReleaseAllocationHandle", {&handle_}, + /*output_arity=*/0, + /*attrs=*/{}, handle_.device_id()); + handle_ = XrtTensorHandle(); + } +} + +xla::StatusOr>> +XrtBuffer::DestructureTuple() { + TF_RET_CHECK(shape_.IsTuple()); + std::vector> output; + output.reserve(shape_.tuple_shapes().size()); + for (int i = 0; i < shape_.tuple_shapes().size(); ++i) { + TensorProto index_proto; + index_proto.set_dtype(DT_INT32); + index_proto.mutable_tensor_shape()->add_dim()->set_size(1); + index_proto.add_int_val(i); + XrtTensorHandle index = + EnqueueConst(handle_.context().get(), handle_.device_id(), index_proto, + /*host_memory=*/true); + XrtTensorHandle sub = std::move( + handle_.context()->EnqueueOp("XRTSubTuple", {&handle_, &index}, + /*output_arity=*/1, + /*attrs=*/{}, handle_.device_id())[0]); + output.push_back( + std::make_shared(std::move(sub), shape_.tuple_shapes(i))); + } + return output; +} + +/*static*/ xla::StatusOr> XrtExecutable::Compile( + std::shared_ptr context, + const xla::HloModuleProto& hlo_module_proto, + const std::vector& argument_shapes, + const xla::Shape& result_shape, xla::DeviceAssignment device_assignment) { + if (device_assignment.replica_count() <= 0 || + device_assignment.computation_count() <= 0) { + return errors::InvalidArgument( + "Device assignment must be non-empty; got ", + device_assignment.replica_count(), " replicas and ", + device_assignment.computation_count(), " computations per replica."); + } + + // TODO(phawkins): add support for per-core argument and return shapes. 
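+ // The xrt::XLAComputation assembled below bundles the serialized HLO module
+ // with its compile-time configuration: replica count, per-replica device
+ // mesh coordinates, and the program shape (with default layouts filled in
+ // where the caller left them unset).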
+ TF_RET_CHECK(device_assignment.computation_count() == 1) + << "Computation count != 1 not implemented"; + + xrt::XLAComputation computation; + computation.mutable_config()->set_num_replicas( + device_assignment.replica_count()); + computation.mutable_config()->set_num_cores_per_replica( + device_assignment.computation_count()); + + xrt::DeviceAssignment* xrt_assignment = + computation.mutable_config()->mutable_device_assignment(); + for (int computation = 0; computation < device_assignment.computation_count(); + ++computation) { + xrt::DeviceAssignment::ComputationDevice* xrt_devices = + xrt_assignment->add_computation_devices(); + for (int replica = 0; replica < device_assignment.replica_count(); + ++replica) { + int xrt_device_ordinal = device_assignment(replica, computation); + if (xrt_device_ordinal < 0 || + xrt_device_ordinal >= context->tf_device_ids().size()) { + return errors::InvalidArgument("Invalid device ordinal in device ", + "assignment: ", xrt_device_ordinal); + } + *xrt_devices->add_replica_devices() = + context->device_mesh_coordinates().at(xrt_device_ordinal); + } + } + + xla::ProgramShape program_shape; + for (const xla::Shape& shape : argument_shapes) { + xla::Shape* param_shape = program_shape.add_parameters(); + *param_shape = shape; + if (!xla::LayoutUtil::HasLayout(shape)) { + xla::LayoutUtil::SetToDefaultLayout(param_shape); + } + } + *program_shape.mutable_result() = result_shape; + if (!xla::LayoutUtil::HasLayout(result_shape)) { + xla::LayoutUtil::SetToDefaultLayout(program_shape.mutable_result()); + } + *computation.mutable_config()->mutable_program_shape() = + program_shape.ToProto(); + *computation.mutable_hlo_snapshot()->mutable_hlo()->mutable_hlo_module() = + hlo_module_proto; + + auto proto = absl::make_unique(); + proto->set_dtype(DT_STRING); + computation.SerializeToString(proto->add_string_val()); + + int xrt_device_ordinal_for_compilation = device_assignment(0, 0); + int tf_device_id = + context->tf_device_ids().at(xrt_device_ordinal_for_compilation); + XrtTensorHandle computation_handle = + context->tf_context()->SendTensor(std::move(proto), tf_device_id, + /*host_memory=*/true); + + XrtTensorHandle executable_handle = + std::move(context->tf_context()->EnqueueOp( + "XRTCompile", {&computation_handle}, /*output_arity=*/2, /*attrs=*/{}, + tf_device_id)[0]); + + if (device_assignment.num_elements() > 1) { + string wire_id = XrtGetUniqueWireID(); + int recv_tf_device_id = context->tf_context()->cpu_device_id(); + EnqueueSend(context->tf_context().get(), executable_handle, DT_INT64, + recv_tf_device_id, wire_id, /*host_memory=*/true); + executable_handle = + EnqueueRecv(context->tf_context().get(), DT_INT64, tf_device_id, + recv_tf_device_id, wire_id, /*host_memory=*/true); + } + + return std::make_shared( + std::move(context), std::move(executable_handle), program_shape, + std::move(device_assignment)); +} + +XrtExecutable::XrtExecutable(std::shared_ptr context, + XrtTensorHandle handle, xla::ProgramShape shape, + xla::DeviceAssignment device_assignment) + : context_(std::move(context)), + handle_(std::move(handle)), + shape_(std::move(shape)), + device_assignment_(std::move(device_assignment)) {} + +XrtExecutable::~XrtExecutable() { Delete(); } + +void XrtExecutable::Delete() { + if (handle_.valid()) { + handle_.context()->EnqueueOp("XRTReleaseCompilationHandle", {&handle_}, + /*output_arity=*/0, + /*attrs=*/{}, handle_.device_id()); + handle_ = XrtTensorHandle(); + } +} + +xla::StatusOr> XrtExecutable::Execute( + const std::vector>& args) { + 
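+ // Single-replica, single-core path: enqueue one XRTExecute op taking the
+ // executable handle, an empty XRTExecutionConfig, and the argument
+ // allocation handles. Replicated or model-parallel computations go through
+ // ExecuteReplicated() instead.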
TF_RET_CHECK(device_assignment_.replica_count() == 1 && + device_assignment_.computation_count() == 1) + << device_assignment_.ToString(); + int xrt_device_ordinal = device_assignment_(0, 0); + int tf_device_id = context_->tf_device_ids().at(xrt_device_ordinal); + + TensorProto config_proto; + config_proto.set_dtype(DT_STRING); + config_proto.add_string_val(); + XrtTensorHandle execution_config_handle = + EnqueueConst(handle_.context().get(), tf_device_id, config_proto, + /*host_memory=*/true); + + protobuf::Map attrs; + attrs["Ninputs"] = MakeAttrValue(args.size()); + + std::vector inputs; + inputs.reserve(args.size() + 2); + inputs.push_back(&handle_); + inputs.push_back(&execution_config_handle); + for (const std::shared_ptr& arg : args) { + if (arg->handle().device_id() != tf_device_id) { + return errors::InvalidArgument( + "Input buffer to Execute() is not on the device for which the " + "computation was compiled. Target device is ", + tf_device_id, ", buffer is on device ", arg->handle().device_id()); + } + inputs.push_back(&arg->handle()); + } + + XrtTensorHandle result_handle = std::move(handle_.context()->EnqueueOp( + "XRTExecute", inputs, /*output_arity=*/1, attrs, tf_device_id)[0]); + + return std::make_shared(std::move(result_handle), shape_.result()); +} + +xla::StatusOr>> +XrtExecutable::ExecuteReplicated( + absl::Span>> args) { + if (args.size() != device_assignment_.computation_count()) { + return errors::InvalidArgument( + "Mismatched number of computation per replica between executable and " + "arguments. Expected computations_per_replica=", + device_assignment_.computation_count(), + "; got computations_per_replica=", args.size()); + } + + for (int computation = 0; + computation < device_assignment_.computation_count(); ++computation) { + if (args[computation].n1() != device_assignment_.replica_count()) { + return errors::InvalidArgument( + "Mismatched number of replicas between executable and arguments for " + " computation ", + computation, + ". Expected replicas=", device_assignment_.replica_count(), + "; got replicas=", args[computation].n1()); + } + for (int replica = 0; replica < device_assignment_.replica_count(); + ++replica) { + int xrt_device_ordinal = device_assignment_(replica, computation); + int tf_device_id = context_->tf_device_ids().at(xrt_device_ordinal); + for (int arg = 0; arg < args[computation].n2(); ++arg) { + const std::shared_ptr& buffer = + args[computation](replica, arg); + if (buffer->handle().device_id() != tf_device_id) { + return errors::InvalidArgument( + "Input buffer to ExecuteReplicated() is not on the device for " + "which the computation was compiled. 
Target device is ", + tf_device_id, ", buffer is on device ", + buffer->handle().device_id()); + } + } + } + } + + std::vector input_arity; + input_arity.reserve(args.size()); + for (const auto& arg : args) { + input_arity.push_back(arg.n2()); + } + TF_ASSIGN_OR_RETURN(string exec_fn, context_->GetExecuteReplicatedFunction( + input_arity, device_assignment_)); + + std::vector input_types; + std::vector inputs; + inputs.push_back(&handle_); + input_types.push_back(DT_INT64); + + std::vector execution_config_handles( + device_assignment_.computation_count()); + int tf_cpu_device_id = context_->tf_context()->cpu_device_id(); + for (int j = 0; j < device_assignment_.computation_count(); ++j) { + TensorProto config_proto; + config_proto.set_dtype(DT_STRING); + xrt::XRTExecutionConfig config; + config.set_core_index_in_replica(j); + config_proto.add_string_val(config.SerializeAsString()); + execution_config_handles[j] = EnqueueConst(context_->tf_context().get(), + tf_cpu_device_id, config_proto, + /*host_memory=*/true); + inputs.push_back(&execution_config_handles[j]); + input_types.push_back(DT_STRING); + } + + for (int i = 0; i < device_assignment_.replica_count(); ++i) { + for (int j = 0; j < device_assignment_.computation_count(); ++j) { + for (int k = 0; k < args[j].n2(); ++k) { + inputs.push_back(&args[j](i, j)->handle()); + input_types.push_back(DT_INT64); + } + } + } + + // Run all the XRTExecute ops in parallel using a multi-device function. + // We do this for two reasons: + // a) we need the operators to run in parallel, but without async mode enabled + // they might not. + // b) we need the operators to all be issued as part of the same + // EnqueueRequest batch, otherwise we will deadlock. + // TODO(phawkins): It would be even better to enable async mode, when its + // error semantics have been improved. + std::vector output_types(device_assignment_.num_elements(), + DT_INT64); + std::vector outputs = context_->tf_context()->EnqueueOp( + exec_fn, inputs, /*output_arity=*/output_types.size(), /*attrs=*/{}, + tf_cpu_device_id); + + xla::Array2D> results( + device_assignment_.computation_count(), + device_assignment_.replica_count()); + int output_num = 0; + for (int i = 0; i < device_assignment_.computation_count(); ++i) { + for (int j = 0; j < device_assignment_.replica_count(); ++j) { + int xrt_device_ordinal = device_assignment_(j, i); // NB. different order + int tf_device_id = context_->tf_device_ids().at(xrt_device_ordinal); + + // EnqueueOp doesn't know about multidevice functions, so it will assume + // that the outputs are on the CPU. Override the device IDs it assigned; + // we know better. + outputs[output_num].set_device_id(tf_device_id); + + // TODO(phawkins): use a per-core result shape here. + results(i, j) = std::make_shared( + std::move(outputs[output_num]), shape_.result()); + ++output_num; + } + } + return results; +} + +/*static*/ xla::StatusOr> XrtContext::Create( + std::shared_ptr tf_context, string device_type) { + auto context = std::make_shared(tf_context, device_type); + if (context->tf_device_ids().empty()) { + return errors::NotFound("No accelerator devices of type ", device_type, + " are present."); + } + if (device_type == "TPU") { + TF_RETURN_IF_ERROR(context->InitializeTPU()); + } else { + // Fill in a dummy topology mapping for CPU/GPU. 
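+ // Each CPU/GPU device gets a trivial one-dimensional mesh coordinate equal
+ // to its XRT device ordinal; only the TPU path queries a real topology.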
+ for (int i = 0; i < context->tf_device_ids().size(); ++i) { + context->device_mesh_coordinates_.push_back({}); + context->device_mesh_coordinates_.back().add_value(i); + } + } + return context; +} + +XrtContext::XrtContext(std::shared_ptr tf_context, + string device_type) + : tf_context_(std::move(tf_context)), device_type_(std::move(device_type)) { + for (int i = 0; i < tf_context_->devices().size(); ++i) { + const DeviceAttributes& device = tf_context_->devices()[i]; + VLOG(2) << "Device: " << i << ": " << device.DebugString(); + if (device.device_type() == device_type_) { + tf_device_ids_.push_back(i); + VLOG(1) << "Accelerator device " << i << ": " << device.name(); + } + } +} + +int XrtContext::device_count() const { return tf_device_ids_.size(); } + +static Status RegisterTPUInitializeFunction(XrtTfContext* context) { + FunctionDef fdef; + OpDef* opdef = fdef.mutable_signature(); + opdef->set_name("TPUInitFunc"); + OpDef::ArgDef* outdef = opdef->add_output_arg(); + outdef->set_name("topology"); + outdef->set_type(DT_STRING); + + NodeDef* ndef = fdef.add_node_def(); + ndef->set_name("n"); + ndef->set_op("ConfigureDistributedTPU"); + + (*fdef.mutable_ret())["topology"] = "n:topology"; + + Status status = context->RegisterFunction(fdef); + VLOG(10) << "RegisterTPUInitializeFunction returned " << status; + return status; +} + +Status XrtContext::InitializeTPU() { + LOG(INFO) << "Initializing TPU devices."; + TF_RETURN_IF_ERROR(RegisterTPUInitializeFunction(tf_context_.get())); + + TensorProto index_proto; + index_proto.set_dtype(DT_INT32); + index_proto.add_int_val(0); + XrtTensorHandle device_ordinal = EnqueueConst( + tf_context_.get(), /*device_id=*/tf_context_->cpu_device_id(), + index_proto, /*host_memory=*/false); + + protobuf::Map attrs; + attrs["f"].mutable_func()->set_name("TPUInitFunc"); + attrs["Tin"].mutable_list(); + attrs["Tout"].mutable_list()->add_type(DT_STRING); + XrtTensorHandle t = std::move( + tf_context_->EnqueueOp("TPUPartitionedCall", {&device_ordinal}, + /*output_arity=*/1, + /*attrs=*/attrs, tf_context_->cpu_device_id())[0]); + + auto result = tf_context_->RecvTensor(t, DT_STRING, /*host_memory=*/false); + TF_ASSIGN_OR_RETURN(RecvTensorResponse * response, result->Get()); + VLOG(10) << "TPU topology " << response->DebugString(); + + TF_ASSIGN_OR_RETURN(std::string data, + DeserializeTensorProtoAsString(response->tensor())); + + tpu::TopologyProto tpu_topology; + tpu_topology.ParsePartialFromString(data); + VLOG(4) << "TPU topology:\n" << tpu_topology.DebugString(); + + TF_RET_CHECK(tpu_topology.num_tasks() == 1) << tpu_topology.DebugString(); + TF_RET_CHECK(tpu_topology.num_tpu_devices_per_task() == tf_device_ids_.size()) + << tpu_topology.DebugString() << " " << tf_device_ids_.size(); + + const int mesh_rank = tpu_topology.mesh_shape_size(); + TF_RET_CHECK(tpu_topology.device_coordinates_size() == + tf_device_ids_.size() * mesh_rank); + + for (int i = 0; i < tf_device_ids_.size(); ++i) { + device_mesh_coordinates_.push_back({}); + auto& coords = device_mesh_coordinates_.back(); + for (int j = 0; j < mesh_rank; ++j) { + coords.add_value(tpu_topology.device_coordinates(i * mesh_rank + j)); + } + } + + LOG(INFO) << "TPU initialization succeeded."; + return Status::OK(); +} + +XrtContext::ExecuteReplicatedKey::ExecuteReplicatedKey( + absl::Span input_arity, xla::DeviceAssignment device_assignment) + : input_arity(input_arity.begin(), input_arity.end()), + device_assignment(std::move(device_assignment)) {} + +bool XrtContext::ExecuteReplicatedKey::operator==( + 
const ExecuteReplicatedKey& other) const { + return input_arity == other.input_arity && + device_assignment == other.device_assignment; +} + +xla::StatusOr XrtContext::GetExecuteReplicatedFunction( + absl::Span input_arity, + const xla::DeviceAssignment& device_assignment) { + ExecuteReplicatedKey key(input_arity, device_assignment); + + absl::MutexLock lock(&mu_); + auto it = replicated_fns_.find(key); + if (it != replicated_fns_.end()) { + return it->second; + } + + string name = absl::StrCat("ExecuteReplicated_", replicated_fns_.size()); + + FunctionDef fdef; + OpDef* opdef = fdef.mutable_signature(); + opdef->set_name(name); + OpDef::ArgDef* execution_handle = opdef->add_input_arg(); + execution_handle->set_name("execution_handle"); + execution_handle->set_type(DT_INT64); + + TF_RET_CHECK(device_assignment.computation_count() == input_arity.size()); + + std::vector execution_configs; + execution_configs.reserve(device_assignment.computation_count()); + for (int j = 0; j < device_assignment.computation_count(); ++j) { + OpDef::ArgDef* execution_config = opdef->add_input_arg(); + execution_config->set_name(absl::StrCat("execution_config_computation", j)); + execution_config->set_type(DT_STRING); + execution_configs.push_back(execution_config); + } + + for (int i = 0; i < device_assignment.replica_count(); ++i) { + for (int j = 0; j < device_assignment.computation_count(); ++j) { + NodeDef* ndef = fdef.add_node_def(); + ndef->set_name(absl::StrFormat("execute_replica%d_computation%d", i, j)); + ndef->set_op("XRTExecute"); + (*ndef->mutable_attr())["Ninputs"] = MakeAttrValue(input_arity[j]); + ndef->add_input(execution_handle->name()); + ndef->add_input(execution_configs[j]->name()); + int tf_device_id = tf_device_ids_.at(device_assignment(i, j)); + ndef->set_device(tf_context_->devices().at(tf_device_id).name()); + + for (int k = 0; k < input_arity[j]; ++k) { + OpDef::ArgDef* arg = opdef->add_input_arg(); + arg->set_name( + absl::StrFormat("in_replica%d_computation%d_arg%d", i, j, k)); + arg->set_type(DT_INT64); + + ndef->add_input(arg->name()); + } + OpDef::ArgDef* ret = opdef->add_output_arg(); + ret->set_name(absl::StrFormat("out_replica%d_computation%d", i, j)); + ret->set_type(DT_INT64); + + (*fdef.mutable_ret())[ret->name()] = + absl::StrCat(ndef->name(), ":output_handle"); + } + } + + VLOG(10) << fdef.DebugString(); + + Status status = tf_context_->RegisterFunction(fdef); + VLOG(4) << "GetExecuteReplicatedFunction returned " << status; + if (!status.ok()) return status; + + replicated_fns_[key] = name; + return name; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/client/xrt_client.h b/tensorflow/compiler/xrt/client/xrt_client.h new file mode 100644 index 00000000000000..d8db2304563db9 --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_client.h @@ -0,0 +1,244 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file contains a C++ client for XRT, that communicates with a remote +// TensorFlow Eager server over gRPC. +// +// This client is a prototype and its API is not stable yet. +// +// TODO(phawkins): add support for multi-host configurations. +// * currently the API names accelerator devices using a flat space of device +// ordinals, with no particular meaning to the device ordinals. The plan is to +// instead to use the linearized device topology coordinates as device +// ordinals. + +#ifndef TENSORFLOW_COMPILER_XRT_CLIENT_XRT_CLIENT_H_ +#define TENSORFLOW_COMPILER_XRT_CLIENT_XRT_CLIENT_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xrt/client/xrt_tf_client.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" + +namespace tensorflow { + +class XrtContext; + +// RAII class that holds ownership of an XRT buffer. +class XrtBuffer { + public: + // Builds a new XrtBuffer from an XLA literal, copying the buffer to the + // remote host. + static xla::StatusOr> FromLiteral( + const std::shared_ptr& context, int xrt_device_ordinal, + const xla::LiteralSlice& literal); + + // Converts an XrtBuffer to an XLA literal, copying the buffer from the remote + // host. Blocks until the buffer is available. + xla::StatusOr ToLiteral() const; + + // Deletes the remote buffer. + void Delete(); + + // Destructures a tuple-shaped buffer into its constituent pieces. + xla::StatusOr>> DestructureTuple(); + + // TODO(phawkins): add a static method for building tuples of buffers. + + // TODO(phawkins): add a mechanism for converting XrtBuffers into remote + // tensors and vice-versa for TF interoperability. + + XrtBuffer() = default; + XrtBuffer(XrtTensorHandle handle, xla::Shape shape); + ~XrtBuffer(); // Calls Delete(). + + // A buffer reference is moveable but not copyable. + XrtBuffer(const XrtBuffer&) = delete; + XrtBuffer(XrtBuffer&&) = default; + XrtBuffer& operator=(const XrtBuffer&) = delete; + XrtBuffer& operator=(XrtBuffer&&) = default; + + const XrtTensorHandle& handle() const { return handle_; } + + private: + // Tensor that contains the XRT allocation ID. + XrtTensorHandle handle_; + xla::Shape shape_; +}; + +// RAII class that holds ownership of an XRT executable. +class XrtExecutable { + public: + // Constructs an XrtExecutable by compiling a program. + // `xrt_device_ordinal` must be the ordinal of a device known to XrtContext + // on which the compile operator should be placed. + // `hlo_module_proto` is the serialized HLO program to compile. + // `argument_shapes` and `result_shape` describe the shapes of the + // arguments/result and their layout. + // `device_assignment` is the set of devices to which compilation should be + // targeted. The device numbers in the device assignment are the XRT device + // ordinals. + // TODO(phawkins): device assignments with more than one computation per + // replica do not work yet, even though the API appears to support them. 
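+ // A typical caller builds `device_assignment` with
+ // xla::ComputationPlacer().AssignDevices(num_replicas, num_computations) and
+ // obtains `hlo_module_proto` from xla::XlaComputation::proto(), as the tests
+ // in xrt_client_test.cc do.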
+ static xla::StatusOr> Compile( + std::shared_ptr context, + const xla::HloModuleProto& hlo_module_proto, + const std::vector& argument_shapes, + const xla::Shape& result_shape, xla::DeviceAssignment device_assignment); + + explicit XrtExecutable(std::shared_ptr context, + XrtTensorHandle handles, xla::ProgramShape shape, + xla::DeviceAssignment device_assignment); + ~XrtExecutable(); // Calls Delete(). + + // Deletes the XrtExecutable. + void Delete(); + + // Runs the executable. Simplified API without replication or model + // parallelism. + xla::StatusOr> Execute( + const std::vector>& args); + + // General API that runs replicated, model-parallel computations. + // + // Arguments are indexed by [computation][replica][arg]. Since each + // computation may have a different arity, we use a Span to represent + // a possibly ragged array. + // + // Return values are indexed by [computation][replica]. XLA computations + // always have exactly one return value, so there is no possibility of + // raggedness. + xla::StatusOr>> ExecuteReplicated( + absl::Span>> args); + + // Moveable but not copyable. + XrtExecutable(const XrtExecutable&) = delete; + XrtExecutable(XrtExecutable&&) = default; + XrtExecutable& operator=(const XrtExecutable&) = delete; + XrtExecutable& operator=(XrtExecutable&&) = default; + + const xla::DeviceAssignment& device_assignment() const { + return device_assignment_; + } + + private: + std::shared_ptr context_; + + // A copy of the executable's handle in host memory. If the computation is + // unreplicated, this lives on the target device. If the computation is + // replicated, this lives on the CPU device. + XrtTensorHandle handle_; + xla::ProgramShape shape_; + + // The TF device ordinal on which this handle was compiled and on which it + // should be deleted. + xla::DeviceAssignment device_assignment_; +}; + +// Manages an XRT session. +// +// The XrtTfClient/XrtTfContext classes wrap the TensorFlow API more directly, +// without any XRT-specific knowledge. The higher level XrtClient +// adds XRT-specific functionality on top. +// +// It is intended that all clients talking to the same XRT session use the same +// XrtContext and that objects such as buffers and executables must not be +// shared between XrtContexts. However, clients may run non-XRT TensorFlow ops +// using the XrtTfContext that underlies an XrtContext. +// +// TODO(phawkins): Currently this code only supports a single remote host; each +// XrtContext communicates via a single XrtTfContext. The plan is to support +// multihost configurations (e.g., TPU pods) in the future, in which case +// XrtContext will be extended to have one XrtTfContext per remote host. +// +// TODO(phawkins): This API is intended to be thread-safe, but this is untested. +class XrtContext { + public: + // Creates an XrtContext. Fails if no accelerators of 'device_type' are found. + static xla::StatusOr> Create( + std::shared_ptr tf_context, string device_type); + + // Use Create() instead. + XrtContext(std::shared_ptr tf_context, string device_type); + + // Returns the number of accelerator devices of 'device_type'. 
+ int device_count() const; + + const std::shared_ptr& tf_context() const { + return tf_context_; + } + const std::vector& tf_device_ids() const { return tf_device_ids_; } + + const std::vector< + xrt::DeviceAssignment::ComputationDevice::DeviceMeshCoordinates>& + device_mesh_coordinates() const { + return device_mesh_coordinates_; + } + + private: + friend class XrtExecutable; + + const std::shared_ptr tf_context_; + const string device_type_; // Type of accelerator device to use (e.g., TPU) + + // Initializes TPU devices. Synchronous; called by Create(). + Status InitializeTPU(); + + // IDs of devices of type `device_type_` in `tf_context_`. + std::vector tf_device_ids_; + + // Device coordinates of each device, indexed by XRT device ordinal. + std::vector + device_mesh_coordinates_; + + // Returns the name of a function that launches a replicated computation + // with input arity `input_arity` and device assignment `device_assignment`. + xla::StatusOr GetExecuteReplicatedFunction( + absl::Span input_arity, + const xla::DeviceAssignment& device_assignment); + + struct ExecuteReplicatedKey { + ExecuteReplicatedKey(absl::Span input_arity, + xla::DeviceAssignment device_assignment); + std::vector input_arity; + xla::DeviceAssignment device_assignment; + bool operator==(const ExecuteReplicatedKey& other) const; + }; + template + friend H AbslHashValue(H h, const ExecuteReplicatedKey& key); + + absl::Mutex mu_; + absl::flat_hash_map replicated_fns_ + GUARDED_BY(mu_); +}; + +template +H AbslHashValue(H h, const XrtContext::ExecuteReplicatedKey& key) { + h = H::combine_contiguous(std::move(h), key.input_arity.data(), + key.input_arity.size()); + return H::combine_contiguous(std::move(h), key.device_assignment.data(), + key.device_assignment.num_elements()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_CLIENT_XRT_CLIENT_H_ diff --git a/tensorflow/compiler/xrt/client/xrt_client_test.cc b/tensorflow/compiler/xrt/client/xrt_client_test.cc index 0fd13b9f745d89..66cda9f3bb9f6e 100644 --- a/tensorflow/compiler/xrt/client/xrt_client_test.cc +++ b/tensorflow/compiler/xrt/client/xrt_client_test.cc @@ -13,8 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/xrt/client/xrt_client.h" + #include +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h" @@ -51,6 +55,9 @@ class XrtClientTest : public ::testing::Test { job->set_name("localhost"); (*job->mutable_tasks())[0] = cluster_->targets()[0]; } + + xla::StatusOr> MakeContext(); + std::unique_ptr cluster_; ClusterDef cluster_def_; }; @@ -136,4 +143,162 @@ TEST_F(XrtClientTest, XrtTfClientWorks) { EXPECT_EQ(out[1], 50); } +xla::StatusOr> XrtClientTest::MakeContext() { + ChannelCreationFunction channel_func = + ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + TF_ASSIGN_OR_RETURN(std::shared_ptr channel_cache, + GetGrpcChannelCache(cluster_def_, channel_func)); + + auto client = std::make_shared(cluster_def_, channel_cache); + TF_ASSIGN_OR_RETURN( + std::shared_ptr tf_context, + XrtTfContext::Create(XrtTfContext::Options(), client, /*job=*/"localhost", + /*task=*/0)); + + TF_ASSIGN_OR_RETURN(auto context, XrtContext::Create(tf_context, "XLA_CPU")); + + // There should be exactly one XLA_CPU device. + TF_RET_CHECK(context->device_count() == 1); + return context; +} + +// Tests that we can use the XRT client to perform some simple operations. +TEST_F(XrtClientTest, XrtClientWorks) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); + + ASSERT_TRUE(context->tf_context() != nullptr); + + EXPECT_EQ(context->tf_device_ids().size(), 1); + + ASSERT_EQ(context->device_mesh_coordinates().size(), 1); + ASSERT_EQ(context->device_mesh_coordinates()[0].value_size(), 1); + EXPECT_EQ(context->device_mesh_coordinates()[0].value(0), 0); + + // Tests sending a literal to and from the device. + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a, + xla::LiteralUtil::CreateRandomLiteral( + shape, + /*mean=*/7.0, /*stddev=*/13.5)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer, XrtBuffer::FromLiteral(context, 0, a)); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b, buffer->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(a, b)); + + // Run a simple computation, fetch its output, and check it is what we expect. + auto build_computation = [&]() { + xla::XlaBuilder builder("test_computation"); + xla::XlaOp p = xla::Parameter(&builder, 0, shape, "param"); + xla::Add(p, p); + return builder.Build(); + }; + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, build_computation()); + + TF_ASSERT_OK_AND_ASSIGN(xla::DeviceAssignment assignment, + xla::ComputationPlacer().AssignDevices(1, 1)); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + XrtExecutable::Compile(context, computation.proto(), + {shape}, shape, assignment)); + EXPECT_EQ(executable->device_assignment(), assignment); + TF_ASSERT_OK_AND_ASSIGN(auto c_buffer, executable->Execute({buffer})); + + xla::Literal expected = a.Clone(); + for (float& elem : expected.data()) { + elem *= 2; + } + TF_ASSERT_OK_AND_ASSIGN(xla::Literal out, c_buffer->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, out)); + + // Explicitly delete the executable, and then compile and run a different + // computation. 
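+ // (Delete() releases the compilation handle on the server immediately;
+ // otherwise the XrtExecutable destructor would do the same when
+ // `executable` goes out of scope.)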
+ executable->Delete(); + + auto build_sub_computation = [&]() { + xla::XlaBuilder builder("test_computation"); + xla::XlaOp p = xla::Parameter(&builder, 0, shape, "p"); + xla::XlaOp q = xla::Parameter(&builder, 1, shape, "q"); + xla::Sub(p, q); + return builder.Build(); + }; + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation sub_computation, + build_sub_computation()); + + TF_ASSERT_OK_AND_ASSIGN( + auto sub_executable, + XrtExecutable::Compile(context, sub_computation.proto(), {shape, shape}, + shape, assignment)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_out, + sub_executable->Execute({c_buffer, buffer})); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal sub_out, buffer_out->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(a, sub_out)); +} + +TEST_F(XrtClientTest, ErrorsPropagateCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a, + xla::LiteralUtil::CreateRandomLiteral( + shape, + /*mean=*/7.0, /*stddev=*/13.5)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer, XrtBuffer::FromLiteral(context, 0, a)); + + auto build_computation = [&]() { + xla::XlaBuilder builder("test_computation"); + xla::XlaOp p = xla::Parameter(&builder, 0, shape, "p"); + xla::XlaOp q = xla::Parameter(&builder, 1, shape, "q"); + xla::Add(p, q); + return builder.Build(); + }; + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, build_computation()); + + TF_ASSERT_OK_AND_ASSIGN(xla::DeviceAssignment assignment, + xla::ComputationPlacer().AssignDevices(1, 1)); + // Call Compile() with an arity mismatch. + TF_ASSERT_OK_AND_ASSIGN(auto sub_executable, + XrtExecutable::Compile(context, computation.proto(), + {shape}, shape, assignment)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_out, sub_executable->Execute({buffer})); + + // The compilation error should be reported when we consumer the computation's + // output. + EXPECT_FALSE(buffer_out->ToLiteral().ok()); + + // Further, we expect a clean shutdown at this point. + context = nullptr; +} + +TEST_F(XrtClientTest, TupleDestructuringAndDelete) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); + + // Tests sending a literal to and from the device. + xla::Shape a_shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a, + xla::LiteralUtil::CreateRandomLiteral( + a_shape, + /*mean=*/7.0, /*stddev=*/13.5)); + + xla::Shape b_shape = xla::ShapeUtil::MakeShape(xla::F64, {2, 7}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b, + xla::LiteralUtil::CreateRandomLiteral( + b_shape, + /*mean=*/3.15, /*stddev=*/-2.1)); + xla::Literal tuple = xla::LiteralUtil::MakeTuple({&a, &b}); + TF_ASSERT_OK_AND_ASSIGN(auto buffer, + XrtBuffer::FromLiteral(context, 0, tuple)); + + TF_ASSERT_OK_AND_ASSIGN(std::vector> pieces, + buffer->DestructureTuple()); + + // Explicitly delete the tuple, which should have no effect on its + // constituents. + buffer->Delete(); + + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a_out, pieces[0]->ToLiteral()); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b_out, pieces[1]->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(a, a_out)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(b, b_out)); + + // Explicitly delete one of the pieces, use RAII to delete the other. 
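+ // (pieces[0] is released by ~XrtBuffer when `pieces` goes out of scope.)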
+ pieces[1]->Delete(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc index 5cbf57461ecf2f..e87233f52ffa6e 100644 --- a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc +++ b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc @@ -33,10 +33,11 @@ XrtGrpcEagerClient::XrtGrpcEagerClient(const SharedGrpcChannelPtr& channel, #define EAGER_CLIENT_METHOD(method) \ void XrtGrpcEagerClient::method##Async( \ const eager::method##Request* request, \ - eager::method##Response* response, StatusCallback done) { \ + eager::method##Response* response, StatusCallback done, \ + CallOptions* call_opts) { \ new RPCState( \ &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \ - response, std::move(done), nullptr, nullptr); \ + response, std::move(done), call_opts, nullptr); \ } EAGER_CLIENT_METHOD(CreateContext); @@ -49,12 +50,12 @@ EAGER_CLIENT_METHOD(SendTensor); #undef EAGER_CLIENT_METHOD #define WORKER_CLIENT_METHOD(method) \ - void XrtGrpcEagerClient::method##Async(const method##Request* request, \ - method##Response* response, \ - StatusCallback done) { \ + void XrtGrpcEagerClient::method##Async( \ + const method##Request* request, method##Response* response, \ + StatusCallback done, CallOptions* call_opts) { \ new RPCState( \ &stub_, cq_, "/tensorflow.WorkerService/" #method, *request, response, \ - std::move(done), nullptr, nullptr); \ + std::move(done), call_opts, nullptr); \ } WORKER_CLIENT_METHOD(GetStatus); diff --git a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h index 5969c22675e5aa..7a04c5a1ea4d41 100644 --- a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h +++ b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "net/grpc/public/include/grpcpp/generic/generic_stub.h" #include "absl/synchronization/notification.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/protobuf/eager_service.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -51,23 +52,28 @@ class XrtGrpcEagerClient { void CreateContextAsync(const eager::CreateContextRequest* request, eager::CreateContextResponse* response, - StatusCallback done); + StatusCallback done, + CallOptions* call_opts = nullptr); void EnqueueAsync(const eager::EnqueueRequest* request, - eager::EnqueueResponse* response, StatusCallback done); + eager::EnqueueResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); void WaitQueueDoneAsync(const eager::WaitQueueDoneRequest* request, eager::WaitQueueDoneResponse* response, - StatusCallback done); + StatusCallback done, + CallOptions* call_opts = nullptr); void KeepAliveAsync(const eager::KeepAliveRequest* request, - eager::KeepAliveResponse* response, StatusCallback done); + eager::KeepAliveResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); void CloseContextAsync(const eager::CloseContextRequest* request, eager::CloseContextResponse* response, - StatusCallback done); + StatusCallback done, CallOptions* call_opts = nullptr); void RegisterFunctionAsync(const eager::RegisterFunctionRequest* request, eager::RegisterFunctionResponse* response, - StatusCallback done); + StatusCallback done, + CallOptions* call_opts = nullptr); void SendTensorAsync(const eager::SendTensorRequest* request, - eager::SendTensorResponse* response, - StatusCallback done); + eager::SendTensorResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); // The following two methods are actually from the WorkerService API, not // EagerService, but are necessary for using remote Eager, and we include them @@ -75,7 +81,8 @@ class XrtGrpcEagerClient { // We use RecvTensor to copy tensors back from a remote worker to the client. void RecvTensorAsync(const RecvTensorRequest* request, - RecvTensorResponse* response, StatusCallback done); + RecvTensorResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); // We use GetStatus to discover device incarnation values for use in // RecvTensor. @@ -83,17 +90,22 @@ class XrtGrpcEagerClient { // TFE server implementation. Remove this API call and use the device // information from CreateContext once the bug fix is deployed everywhere. void GetStatusAsync(const GetStatusRequest* request, - GetStatusResponse* response, StatusCallback done); + GetStatusResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); // Helper method for calling any of the ...Async methods synchronously. 
template - Status SyncCall(Method m, const Request* request, Response* response) { + Status SyncCall(Method m, const Request* request, Response* response, + CallOptions* call_opts = nullptr) { absl::Notification done; Status status; - (this->*(m))(request, response, [&](Status s) { - status = s; - done.Notify(); - }); + (this->*(m))( + request, response, + [&](Status s) { + status = s; + done.Notify(); + }, + call_opts); done.WaitForNotification(); return status; } diff --git a/tensorflow/compiler/xrt/client/xrt_tf_client.cc b/tensorflow/compiler/xrt/client/xrt_tf_client.cc index 5bbe23ebc23df0..32de6e68b9d184 100644 --- a/tensorflow/compiler/xrt/client/xrt_tf_client.cc +++ b/tensorflow/compiler/xrt/client/xrt_tf_client.cc @@ -211,8 +211,11 @@ void XrtTfContext::ReportError(absl::Span op_ids, while (!stack.empty()) { Operation* op = stack.top(); stack.pop(); + VLOG(10) << "Reporting error for " << op->id; for (const std::shared_ptr& future : op->tensor_futures) { + VLOG(10) << "Reporting error for " << op->id << " future"; + future->call_options_.StartCancel(); future->Notify(status); } for (OperationId consumer_id : op->consumers) { @@ -242,7 +245,7 @@ XrtTfContext::Operation* XrtTfContext::LookupOperation(OperationId id) { std::vector XrtTfContext::EnqueueOp( absl::string_view name, absl::Span inputs, int output_arity, protobuf::Map attrs, - int device_id) { + int device_id, std::shared_ptr future) { std::vector outputs; absl::MutexLock lock(&mu_); Operation* op = AddOperation(); @@ -261,6 +264,9 @@ std::vector XrtTfContext::EnqueueOp( outputs.push_back( XrtTensorHandle(shared_from_this(), device_id, TensorId{op->id, i})); } + if (future) { + op->tensor_futures.push_back(future); + } return outputs; } @@ -285,17 +291,25 @@ XrtTensorHandle XrtTfContext::SendTensor( request.set_device_name(devices_.at(rpc_device_id).name()); auto response = std::make_shared(); auto context_ptr = shared_from_this(); - eager_client_->SendTensorAsync(&request, response.get(), - [context_ptr, op_id, response](Status status) { - absl::MutexLock lock(&context_ptr->mu_); - if (!status.ok()) { - context_ptr->ReportError({op_id}, status); - } else { - context_ptr->DeleteOperation(op_id); - } - }); + absl::Notification done; + eager_client_->SendTensorAsync( + &request, response.get(), + [context_ptr, op_id, response, &done](Status status) { + absl::MutexLock lock(&context_ptr->mu_); + if (!status.ok()) { + context_ptr->ReportError({op_id}, status); + } else { + context_ptr->DeleteOperation(op_id); + } + done.Notify(); + }); XrtTensorHandle handle(context_ptr, rpc_device_id, TensorId{op_id, 0}); + // TODO(phawkins): we block here to avoid a race. We must not + // enqueue any dependent operations until the SendTensor has been + // acknowledged. + done.WaitForNotification(); + // TODO(phawkins): EagerService.SendTensor could use a host_memory option. 
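+ // (Absent that option, host-memory transfers appear to be staged through
+ // the CPU device; see the transfer_via_cpu_device path below.)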
if (!transfer_via_cpu_device) { return handle; @@ -334,11 +348,13 @@ static std::string GetRendezvousKey(absl::string_view send_device, std::shared_ptr XrtTfContext::RecvTensor( const XrtTensorHandle& tensor, DataType dtype, bool host_memory) { + auto response = std::make_shared(); + int device_id = tensor.device_id(); std::string wire_id = XrtGetUniqueWireID(); EnqueueSend(this, tensor, dtype, /*recv_device_id=*/-1, wire_id, - /*host_memory=*/false); + /*host_memory=*/false, /*future=*/response); const DeviceAttributes& device = devices().at(device_id); RecvTensorRequest request; @@ -347,12 +363,13 @@ std::shared_ptr XrtTfContext::RecvTensor( GetReceiverDevice(this, -1), device.incarnation(), wire_id)); - auto response = std::make_shared(); - eager_client_->RecvTensorAsync(&request, &response->value_, - [response](Status status) { - VLOG(10) << "RecvTensor complete\n"; - response->Notify(status); - }); + eager_client_->RecvTensorAsync( + &request, &response->value_, + [response](Status status) { + VLOG(10) << "RecvTensor complete\n"; + response->Notify(status); + }, + &response->call_options_); return response; } @@ -460,7 +477,8 @@ AttrValue MakeAttrValue(absl::Span dtypes) { void EnqueueSend(XrtTfContext* context, const XrtTensorHandle& tensor, DataType dtype, int recv_device_id, std::string wire_id, - bool host_memory) { + bool host_memory, + std::shared_ptr future) { protobuf::Map attrs; const DeviceAttributes& device = context->devices().at(tensor.device_id()); attrs["tensor_name"] = MakeAttrValue(wire_id); @@ -472,7 +490,8 @@ void EnqueueSend(XrtTfContext* context, const XrtTensorHandle& tensor, attrs["T"] = MakeAttrValue(dtype); context->EnqueueOp(host_memory ? "_HostSend" : "_Send", {&tensor}, - /*output_arity=*/0, std::move(attrs), tensor.device_id()); + /*output_arity=*/0, std::move(attrs), tensor.device_id(), + future); } XrtTensorHandle EnqueueRecv(XrtTfContext* context, DataType dtype, diff --git a/tensorflow/compiler/xrt/client/xrt_tf_client.h b/tensorflow/compiler/xrt/client/xrt_tf_client.h index f2a2e94bdb8daa..73caa6d746c12a 100644 --- a/tensorflow/compiler/xrt/client/xrt_tf_client.h +++ b/tensorflow/compiler/xrt/client/xrt_tf_client.h @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -22,10 +22,6 @@ limitations under the License. // The API is intended to be minimal and does not take dependencies on classes // such as Tensor or Device. // -// The XrtTfClient/XrtTfContext classes are intended to wrap the TensorFlow API -// more directly, without any XRT-specific knowledge. The higher level XrtClient -// adds XRT-specific functionality on top. -// // The main feature this client adds over the remote eager TF client is // batching. Rather than synchronously executing each operator, the client // accumulates batches of operators and enqueues them as a unit. This is @@ -59,6 +55,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h" #include "tensorflow/compiler/xrt/client/xrt_tf_client.h" +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/platform/env.h" @@ -88,6 +85,7 @@ class XrtTensorHandle; class XrtRecvTensorFuture; // Class that manages a TensorFlow Eager context. +// TODO(phawkins): Intended to be thread-safe class XrtTfContext : public std::enable_shared_from_this { public: struct Options { @@ -123,16 +121,17 @@ class XrtTfContext : public std::enable_shared_from_this { XrtTensorHandle SendTensor(std::unique_ptr tensor_proto, int device_id, bool host_memory = false); - // Receives `tensor` from the remote host. Also has the side effect of - // sending any enqueued operations to the remote worker. + // Receives `tensor` from the remote host. Does not flush the queue. std::shared_ptr RecvTensor(const XrtTensorHandle& tensor, DataType dtype, bool host_memory); // Enqueues an operator onto the remote host. + // 'future' is an optional future that depends on the op. std::vector EnqueueOp( absl::string_view name, absl::Span inputs, - int output_arity, protobuf::Map attrs, int device_id); + int output_arity, protobuf::Map attrs, int device_id, + std::shared_ptr future = {}); // Registers a function `def` on the remote host. Status RegisterFunction(const FunctionDef& def); @@ -295,6 +294,8 @@ class XrtRecvTensorFuture { absl::Notification done_; Status status_ GUARDED_BY(mu_); RecvTensorResponse value_ GUARDED_BY(mu_); + + CallOptions call_options_; }; // This gets a unique wire ID. We add a random identifier so that if the @@ -307,9 +308,12 @@ std::string XrtGetUniqueWireID(); // remote worker. If recv_device_id < 0 the target of the send is the client, // and a fake device name is used (since the client has no real name in the // TF cluster). +// 'future' may be null. If non-null it gives a future that depends on the +// output of the send and that must be aborted if the send fails. void EnqueueSend(XrtTfContext* context, const XrtTensorHandle& tensor, DataType dtype, int recv_device_id, std::string wire_id, - bool host_memory); + bool host_memory, + std::shared_ptr future = {}); // Enqueues a _Recv operator that receives a tensor onto a remote device. XrtTensorHandle EnqueueRecv(XrtTfContext* context, DataType dtype,