From a1d6179adb1ca6208281ed955860c319525edf75 Mon Sep 17 00:00:00 2001
From: Asim Shankar
Date: Tue, 26 Jun 2018 18:51:41 -0700
Subject: [PATCH] [C++]: Ability to feed and fetch tensors while keeping them
 in device memory when using Session::RunCallable().

PiperOrigin-RevId: 202234757
---
 tensorflow/core/BUILD                         |   8 +-
 .../common_runtime/direct_session_test.cc     | 339 ++++++++++++++++--
 .../common_runtime/graph_execution_state.cc   | 176 ++++++++-
 .../rpc/grpc_session_test.cc                  |  33 ++
 tensorflow/core/protobuf/config.proto         |  64 +++-
 5 files changed, 564 insertions(+), 56 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 0e411703672380..4bb1bf0dab329a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3911,13 +3911,13 @@ tf_cc_test(
     ],
 )
 
-tf_cc_test(
+tf_cuda_cc_test(
     name = "common_runtime_direct_session_test",
     size = "small",
     srcs = ["common_runtime/direct_session_test.cc"],
+    args = [] + if_cuda(["--heap_check=local"]),  # The GPU tracer leaks memory
     linkstatic = tf_kernel_tests_linkstatic(),
     deps = [
-        ":core",
         ":core_cpu",
         ":core_cpu_internal",
         ":direct_session_internal",
@@ -3930,6 +3930,7 @@ tf_cc_test(
         ":test",
         ":test_main",
         ":testlib",
+        "//third_party/eigen3",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/core/kernels:control_flow_ops",
         "//tensorflow/core/kernels:cwise_op",
@@ -3943,8 +3944,7 @@ tf_cc_test(
         "//tensorflow/core/kernels:queue_ops",
         "//tensorflow/core/kernels:session_ops",
         "//tensorflow/core/kernels:variable_ops",
-        "//third_party/eigen3",
-    ],
+    ] + if_cuda([":cuda"]),
 )
 
 # This is identical to :common_runtime_direct_session_test with the addition of
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 8ddc9958b2259f..5b424230ca550b 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -40,6 +40,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
@@ -47,6 +48,11 @@ limitations under the License.
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
+#ifdef GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#include "cuda/include/cuda_runtime_api.h"
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 namespace {
 
@@ -1233,36 +1239,23 @@ TEST(DirectSessionTest, TimeoutSession) {
       device: '/device:CPU:0'
       attr {
        key: 'capacity'
-        value {
-          i: 10
-        }
+        value { i: 10 }
      }
      attr {
        key: 'component_types'
-        value {
-          list {
-            type: DT_FLOAT
-          }
-        }
+        value { list { type: DT_FLOAT } }
      }
      attr {
        key: 'container'
-        value {
-          s: ''
-        }
+        value { s: '' }
      }
      attr {
        key: 'shapes'
-        value {
-          list {
-          }
-        }
+        value { list {} }
      }
      attr {
        key: 'shared_name'
-        value {
-          s: ''
-        }
+        value { s: '' }
      }
    }
    node {
@@ -1272,24 +1265,15 @@ TEST(DirectSessionTest, TimeoutSession) {
      device: '/device:CPU:0'
      attr {
        key: 'component_types'
-        value {
-          list {
-            type: DT_FLOAT
-          }
-        }
+        value { list { type: DT_FLOAT } }
      }
      attr {
        key: 'timeout_ms'
-        value {
-          i: -1
-        }
+        value { i: -1 }
      }
    }
-    versions {
-      producer: 9
-    }
-  )proto",
-      &graph);
+    versions { producer: 9 }
+  )proto", &graph);
 
   {
     // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
@@ -1352,11 +1336,8 @@ TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
      op: 'CancellationMgrPollingOp'
      device: '/device:CPU:0'
    }
-    versions {
-      producer: 9
-    }
-  )proto",
-      &graph);
+    versions { producer: 9 }
+  )proto", &graph);
 
   // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
   SessionOptions options;
@@ -1730,6 +1711,292 @@ TEST(DirectSessionTest, LocalDeviceManager) {
   EXPECT_GT(mgr->ListDevices().size(), 0);
 }
 
+// y = tf.square(x)
+GraphDef CreateGraphForYEqualsXSquared() {
+  GraphDef graph_def;
+  QCHECK(protobuf::TextFormat::ParseFromString(
+      R"EOF(
+node {
+  name: "x"
+  op: "Placeholder"
+  attr { key: "dtype" value { type: DT_FLOAT } }
+  attr { key: "shape" value { shape { unknown_rank: true } } }
+}
+node {
+  name: "y"
+  op: "Square"
+  input: "x"
+  attr { key: "T" value { type: DT_FLOAT } }
+}
+versions {
+  producer: 26
+}
+  )EOF",
+      &graph_def));
+  return graph_def;
+}
+
+// A graph that consumes and produces string tensors
+// (which are not GPU-compatible, i.e., there are no
+// GPU kernels for these operations).
+bool IsCUDATensor(const Tensor& t) {
+#ifdef GOOGLE_CUDA
+  cudaPointerAttributes attributes;
+  cudaError_t err =
+      cudaPointerGetAttributes(&attributes, t.tensor_data().data());
+  if (err == cudaErrorInvalidValue) return false;
+  CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
+  return (attributes.memoryType == cudaMemoryTypeDevice);
+#else
+  return false;
+#endif
+}
+
+string GPUDeviceName(Session* session) {
+  std::vector<DeviceAttributes> devices;
+  TF_CHECK_OK(session->ListDevices(&devices));
+  for (const DeviceAttributes& d : devices) {
+    if (d.device_type() == "GPU" || d.device_type() == "gpu") {
+      return d.name();
+    }
+  }
+  return "";
+}
+
+TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory) {
+  std::unique_ptr<Session> session(NewSession(SessionOptions()));
+  const string gpu_device_name = GPUDeviceName(session.get());
+  if (gpu_device_name.empty()) {
+    LOG(INFO) << "Skipping test since no GPU is available";
+    return;
+  }
+
+  TF_ASSERT_OK(session->Create(CreateGraphForYEqualsXSquared()));
+
+  CallableOptions opts;
+  opts.add_feed("x:0");
+  opts.add_fetch("y:0");
+
+  Tensor gpu_tensor;
+
+  {
+    Session::CallableHandle feed_cpu_fetch_gpu;
+    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
+    opts.set_fetch_skip_sync(true);
+    TF_ASSERT_OK(session->MakeCallable(opts, &feed_cpu_fetch_gpu));
+    Tensor input(DT_FLOAT, {});
+    input.scalar<float>()() = 2.0f;
+    std::vector<Tensor> outputs;
+    TF_ASSERT_OK(
+        session->RunCallable(feed_cpu_fetch_gpu, {input}, &outputs, nullptr));
+    TF_ASSERT_OK(session->ReleaseCallable(feed_cpu_fetch_gpu));
+    ASSERT_EQ(1, outputs.size());
+    gpu_tensor = outputs[0];
+    ASSERT_TRUE(IsCUDATensor(gpu_tensor));
+  }
+
+  {
+    Session::CallableHandle feed_gpu_fetch_cpu;
+    opts.clear_fetch_devices();
+    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
+    TF_ASSERT_OK(session->MakeCallable(opts, &feed_gpu_fetch_cpu));
+    std::vector<Tensor> outputs;
+    TF_ASSERT_OK(session->RunCallable(feed_gpu_fetch_cpu, {gpu_tensor},
+                                      &outputs, nullptr));
+    TF_ASSERT_OK(session->ReleaseCallable(feed_gpu_fetch_cpu));
+    ASSERT_EQ(1, outputs.size());
+    // The output is in CPU/host memory, so it can be dereferenced.
+    ASSERT_EQ(16.0, outputs[0].scalar<float>()());
+  }
+}
+
+GraphDef CreateIdentityGraphDef(DataType dtype) {
+  GraphDef def;
+
+  AttrValue dtype_attr;
+  dtype_attr.set_type(dtype);
+
+  AttrValue shape_attr;
+  shape_attr.mutable_shape()->set_unknown_rank(true);
+
+  auto* placeholder = def.add_node();
+  placeholder->set_name("x");
+  placeholder->set_op("Placeholder");
+  placeholder->mutable_attr()->insert({"dtype", dtype_attr});
+  placeholder->mutable_attr()->insert({"shape", shape_attr});
+
+  auto* identity = def.add_node();
+  identity->set_name("y");
+  identity->set_op("Identity");
+  identity->add_input("x");
+  identity->mutable_attr()->insert({"T", dtype_attr});
+
+  return def;
+}
+
+void TestFeedAndFetchTensorsInDeviceMemory(
+    const SessionOptions& session_options, DataType dtype) {
+  std::unique_ptr<Session> session(NewSession(session_options));
+  const string gpu_device_name = GPUDeviceName(session.get());
+  if (gpu_device_name.empty()) {
+    LOG(INFO) << "Skipping test since no GPU is available";
+    return;
+  }
+
+  TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
+      << DataType_Name(dtype);
+
+  CallableOptions opts;
+  opts.add_feed("x:0");
+  opts.add_fetch("y:0");
+
+  Tensor gpu_tensor;
+  Tensor host_tensor(dtype, {3});
+  {
+    // Ask for the fetched tensor to be backed by device memory, even
+    // though the kernel that created the tensor produced it in host
+    // memory.
+    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
+    opts.set_fetch_skip_sync(true);
+    Session::CallableHandle handle;
+    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
+    std::vector<Tensor> outputs;
+    TF_ASSERT_OK(session->RunCallable(handle, {host_tensor}, &outputs, nullptr))
+        << DataType_Name(dtype);
+    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
+    ASSERT_EQ(1, outputs.size()) << DataType_Name(dtype);
+    gpu_tensor = outputs[0];
+    ASSERT_TRUE(IsCUDATensor(gpu_tensor)) << DataType_Name(dtype);
+  }
+
+  {
+    // Feed a tensor backed by device memory, even though the operations in the
+    // graph expect it in host memory.
+    opts.clear_fetch_devices();
+    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
+    Session::CallableHandle handle;
+    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
+    std::vector<Tensor> outputs;
+    TF_ASSERT_OK(session->RunCallable(handle, {gpu_tensor}, &outputs, nullptr))
+        << DataType_Name(dtype);
+    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
+    ASSERT_EQ(1, outputs.size());
+    const StringPiece actual_data = outputs[0].tensor_data();
+    const StringPiece expected_data = host_tensor.tensor_data();
+    EXPECT_EQ(expected_data.size(), actual_data.size()) << DataType_Name(dtype);
+    EXPECT_EQ(0, memcmp(expected_data.data(), actual_data.data(),
+                        std::min(expected_data.size(), actual_data.size())))
+        << DataType_Name(dtype);
+  }
+}
+
+void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(
+    const SessionOptions& session_options, DataType dtype) {
+  std::unique_ptr<Session> session(NewSession(session_options));
+  const string gpu_device_name = GPUDeviceName(session.get());
+  if (gpu_device_name.empty()) {
+    LOG(INFO) << "Skipping test since no GPU is available";
+    return;
+  }
+
+  TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
+      << DataType_Name(dtype);
+
+  CallableOptions opts;
+  opts.add_feed("x:0");
+  opts.add_fetch("y:0");
+
+  // Fail when asking to fetch into GPU memory.
+  {
+    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
+    opts.set_fetch_skip_sync(true);
+    Session::CallableHandle handle;
+    Status status = session->MakeCallable(opts, &handle);
+    EXPECT_FALSE(status.ok()) << DataType_Name(dtype);
+    EXPECT_TRUE(str_util::StrContains(
+        status.error_message(),
+        strings::StrCat(
+            "Cannot feed or fetch tensor 'y:0' from device ", gpu_device_name,
+            " as feeding/fetching from GPU devices is not yet supported for ",
+            DataTypeString(dtype), " tensors")))
+        << DataType_Name(dtype) << ", Status: " << status;
+  }
+
+  // Fail when feeding from GPU memory.
+  {
+    opts.clear_feed_devices();
+    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
+    Session::CallableHandle handle;
+    Status status = session->MakeCallable(opts, &handle);
+    EXPECT_FALSE(status.ok());
+    EXPECT_TRUE(str_util::StrContains(
+        status.error_message(),
+        strings::StrCat(
+            "Cannot feed or fetch tensor 'x:0' from device ", gpu_device_name,
+            " as feeding/fetching from GPU devices is not yet supported for ",
+            DataTypeString(dtype), " tensors")))
+        << DataType_Name(dtype) << ", Status: " << status;
+  }
+}
+
+void TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(
+    const SessionOptions& opts) {
+  // Feeding/fetching on device does not work for all DataTypes as it
+  // relies on the implementation of the _Arg and _Retval kernels which
+  // are not registered for some types or consume/produce inputs/outputs
+  // in host memory for some types.
+  //
+  // Run through all datatypes to validate that either:
+  // (a) MakeCallable fails (because the given type cannot be fed/fetched
+  //     in device memory),
+  //     OR
+  // (b) Succeeds: RunCallable should gladly accept inputs in device memory
+  //     and produce output tensors in device memory.
+  for (int i = DataType_MIN; i <= DataType_MAX; ++i) {
+    if (!DataType_IsValid(i)) continue;
+    const DataType dtype = static_cast<DataType>(i);
+    switch (dtype) {
+      case DT_INVALID:
+        break;
+      case DT_BFLOAT16:
+      case DT_BOOL:
+      case DT_COMPLEX128:
+      case DT_COMPLEX64:
+      case DT_DOUBLE:
+      case DT_FLOAT:
+      case DT_HALF:
+      case DT_INT16:
+      case DT_INT64:
+      case DT_INT8:
+      case DT_UINT16:
+      case DT_UINT8:
+        TestFeedAndFetchTensorsInDeviceMemory(opts, dtype);
+        break;
+      default:
+        // Ignore all REF types since Tensors of this type aren't intended to
+        // be fed (and attempting to create one via the Tensor constructor
+        // will result in a LOG(FATAL)).
+        if (!IsRefType(dtype)) {
+          TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(opts, dtype);
+        }
+        break;
+    }
+  }
+}
+
+TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory_AllDataTypes) {
+  SessionOptions opts;
+  opts.config.set_allow_soft_placement(false);
+  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(opts);
+}
+
+TEST(DirectSessionTest,
+     FeedAndFetchTensorsInDeviceMemory_AllDataTypes_SoftPlacement) {
+  SessionOptions opts;
+  opts.config.set_allow_soft_placement(true);
+  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(opts);
+}
+
 // A simple benchmark for the overhead of `DirectSession::Run()` calls
 // with varying numbers of feeds/fetches.
 void FeedFetchBenchmarkHelper(int iters, int num_feeds,
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 58018689d57c8d..9c9eacb5b5e2b5 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -280,6 +280,118 @@ class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
   NodeBuilder::NodeOut from_tensor_;
 };
 
+template <class Map>
+Status LookupDevice(const DeviceSet& device_set, const string& tensor_name,
+                    const Map& tensor2device,
+                    const tensorflow::DeviceAttributes** out_device_attrs) {
+  *out_device_attrs = nullptr;
+  if (tensor2device.empty()) {
+    *out_device_attrs = &device_set.client_device()->attributes();
+    return Status::OK();
+  }
+  const auto it = tensor2device.find(tensor_name);
+  if (it == tensor2device.end()) {
+    *out_device_attrs = &device_set.client_device()->attributes();
+    return Status::OK();
+  }
+  DeviceNameUtils::ParsedName parsed_name;
+  if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) {
+    return errors::InvalidArgument("Invalid device name ('", it->second,
+                                   "') provided for the tensor '", tensor_name,
+                                   "' in CallableOptions");
+  }
+  Device* device = device_set.FindDeviceByName(
+      DeviceNameUtils::ParsedNameToString(parsed_name));
+  if (device == nullptr) {
+    return errors::InvalidArgument("Device '", it->second,
+                                   "' specified for tensor '", tensor_name,
+                                   "' in CallableOptions does not exist");
+  }
+  *out_device_attrs = &device->attributes();
+  return Status::OK();
+}
+
+struct TensorAndDevice {
+  // WARNING: backing memory for the 'tensor' field is NOT owned.
+  const TensorId tensor;
+  // WARNING: device pointer is not owned, so must outlive TensorAndDevice.
+  const DeviceAttributes* device;
+};
+
+// Tensors of some DataTypes cannot be placed in device memory as feeds or
+// fetches. Validate against a whitelist of those known to work.
+bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
+  // The mechanism for supporting feeds of device-backed Tensors requires
+  // the _Arg kernel to be registered for the corresponding type (and that
+  // the input to the kernel be in device and not host memory).
+  //
+  // The mechanism for supporting fetches of device-backed Tensors requires
+  // the _Retval kernel to be registered for the corresponding type (and
+  // that the output is produced in device and not host memory).
+  //
+  // For now, we return true iff there are _Arg AND _Retval kernels for dtype on
+  // the device. False negatives are okay, false positives would be bad.
+  //
+  // TODO(ashankar): Instead of a whitelist here, perhaps we could query
+  // the kernel registry for _Arg and _Retval kernels instead.
+  if (device_type == DEVICE_CPU) return true;
+  if (device_type != DEVICE_GPU) return false;
+  switch (dtype) {
+    case DT_BFLOAT16:
+    case DT_BOOL:
+    case DT_COMPLEX128:
+    case DT_COMPLEX64:
+    case DT_DOUBLE:
+    case DT_FLOAT:
+    case DT_HALF:
+    case DT_INT16:
+    case DT_INT64:
+    case DT_INT8:
+    case DT_UINT16:
+    case DT_UINT8:
+      return true;
+    default:
+      return false;
+  }
+}
+
+Status ValidateFeedAndFetchDevices(
+    const Graph& graph,
+    const std::vector<TensorAndDevice>& tensors_and_devices) {
+  if (tensors_and_devices.empty()) return Status::OK();
+  std::vector<bool> found(tensors_and_devices.size(), false);
+  for (const Node* node : graph.nodes()) {
+    // Linearly looping through all nodes and then all feed+fetch tensors isn't
+    // quite efficient. At the time of this writing, the expectation was that
+    // tensors_and_devices.size() is really small in practice, so this won't be
+    // problematic.
+    // Revisit and make a more efficient lookup possible if needed (e.g.,
+    // perhaps Graph can maintain a map from node name to Node*).
+    for (int i = 0; i < tensors_and_devices.size(); ++i) {
+      const TensorAndDevice& td = tensors_and_devices[i];
+      if (td.tensor.first != node->name()) continue;
+      found[i] = true;
+      TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second));
+      const DataType dtype = node->output_type(td.tensor.second);
+      if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) {
+        return errors::Unimplemented(
+            "Cannot feed or fetch tensor '", td.tensor.ToString(),
+            "' from device ", td.device->name(), " as feeding/fetching from ",
+            td.device->device_type(), " devices is not yet supported for ",
+            DataTypeString(dtype), " tensors");
+      }
+    }
+  }
+  for (int i = 0; i < found.size(); ++i) {
+    if (!found[i]) {
+      return errors::InvalidArgument(
+          "Tensor ", tensors_and_devices[i].tensor.ToString(),
+          ", specified in either feed_devices or fetch_devices was not found "
+          "in the Graph");
+    }
+  }
+  return Status::OK();
+}
 }  // namespace
 
 Status GraphExecutionState::PruneGraph(
@@ -289,18 +401,52 @@ Status GraphExecutionState::PruneGraph(
   feed_rewrites.reserve(options.callable_options.feed_size());
   std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
   fetch_rewrites.reserve(options.callable_options.fetch_size());
-  const DeviceAttributes* device_info =
-      &device_set_->client_device()->attributes();
   if (options.use_function_convention) {
+    std::vector<TensorAndDevice> tensors_and_devices;
     for (int i = 0; i < options.callable_options.feed_size(); ++i) {
-      feed_rewrites.emplace_back(new subgraph::ArgFeedRewrite(
-          &options.callable_options.feed(i), device_info, i));
+      // WARNING: feed MUST be a reference, since ArgFeedRewrite and
+      // tensors_and_devices hold on to its address.
+      const string& feed = options.callable_options.feed(i);
+      const DeviceAttributes* device_info;
+      TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed,
+                                      options.callable_options.feed_devices(),
+                                      &device_info));
+      feed_rewrites.emplace_back(
+          new subgraph::ArgFeedRewrite(&feed, device_info, i));
+      tensors_and_devices.push_back({ParseTensorName(feed), device_info});
+    }
+    if (!options.callable_options.fetch_devices().empty() &&
+        !options.callable_options.fetch_skip_sync()) {
+      return errors::Unimplemented(
+          "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
+          "can set it to true instead, but MUST ensure that Device::Sync() is "
+          "invoked on the Device corresponding to the fetched tensor before "
+          "dereferencing the Tensor's memory.");
     }
     for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
-      fetch_rewrites.emplace_back(new subgraph::RetvalFetchRewrite(
-          &options.callable_options.fetch(i), device_info, i));
+      // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and
+      // tensors_and_devices hold on to its address.
+      const string& fetch = options.callable_options.fetch(i);
+      const DeviceAttributes* device_info;
+      TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch,
+                                      options.callable_options.fetch_devices(),
+                                      &device_info));
+      fetch_rewrites.emplace_back(
+          new subgraph::RetvalFetchRewrite(&fetch, device_info, i));
+      tensors_and_devices.push_back({ParseTensorName(fetch), device_info});
     }
+    TF_RETURN_IF_ERROR(
+        ValidateFeedAndFetchDevices(*graph, tensors_and_devices));
   } else {
+    if (!options.callable_options.feed_devices().empty() ||
+        !options.callable_options.fetch_devices().empty()) {
+      return errors::Unimplemented(
+          "CallableOptions::feed_devices and CallableOptions::fetch_devices "
+          "to configure feeding/fetching tensors to/from device memory is not "
+          "yet supported when using a remote session.");
+    }
+    const DeviceAttributes* device_info =
+        &device_set_->client_device()->attributes();
     for (const string& feed : options.callable_options.feed()) {
       feed_rewrites.emplace_back(
           new subgraph::RecvFeedRewrite(&feed, device_info));
@@ -455,11 +601,11 @@ Status GraphExecutionState::OptimizeGraph(
        return errors::InvalidArgument("Missing node shape or type");
      }
      TensorShapeProto shape_proto(node.attr().at("shape").shape());
-      // If the shape of the placeholder value is only partially known, we're
-      // free to use any dimension we want to feed the placeholder. We choose
-      // 1 to minimize the memory impact. Note that this only matters if an
-      // optimizer choose to run the graph to build its cost model, which
-      // doesn't happen (yet)
+      // If the shape of the placeholder value is only partially known,
+      // we're free to use any dimension we want to feed the placeholder. We
+      // choose 1 to minimize the memory impact. Note that this only matters
+      // if an optimizer chooses to run the graph to build its cost model,
+      // which doesn't happen (yet).
      if (shape_proto.unknown_rank()) {
        shape_proto.set_unknown_rank(false);
      }
@@ -513,10 +659,10 @@ Status GraphExecutionState::OptimizeGraph(
   opts.allow_internal_ops = true;
   TF_RETURN_IF_ERROR(
       ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get()));
-  // The graph conversion sets the requested device names but not the assigned
-  // device names. However, since at this point the graph is placed TF expects
-  // an assigned device name for every node. Therefore we copy the requested
-  // device into the assigned device field.
+  // The graph conversion sets the requested device names but not the
+  // assigned device names. However, since at this point the graph is placed,
+  // TF expects an assigned device name for every node. Therefore we copy
+  // the requested device into the assigned device field.
   for (Node* node : optimized_graph->get()->nodes()) {
     node->set_assigned_device_name(node->requested_device());
   }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 45b15a54a29b48..fc601991a24d57 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -163,6 +163,39 @@ TEST(GrpcSessionTest, BasicCallable) {
   }
 }
 
+TEST(GrpcSessionTest, CallableWithOnDeviceFeedsAndFetches) {
+  // Specifying feed/fetch devices for remote sessions is not yet defined.
+  // Ensure that the error is graceful.
+  GraphDef graph;
+  string node_names[3];
+  // c = a * b
+  CreateGraphDef(&graph, node_names);
+
+  std::unique_ptr<test::TestCluster> cluster;
+  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+  std::unique_ptr<Session> session(
+      NewRemote(Options(cluster->targets()[0], 1)));
+  ASSERT_TRUE(session != nullptr);
+
+  TF_CHECK_OK(session->Create(graph));
+
+  std::vector<DeviceAttributes> devices;
+  TF_CHECK_OK(session->ListDevices(&devices));
+  ASSERT_GT(devices.size(), 0);
+  const string device_name = devices.back().name();
+
+  CallableOptions opts;
+  const string fetch = node_names[2] + ":0";
+  opts.add_fetch(fetch);
+  opts.mutable_fetch_devices()->insert({fetch, device_name});
+
+  Session::CallableHandle handle;
+  Status status = session->MakeCallable(opts, &handle);
+  EXPECT_EQ(error::UNIMPLEMENTED, status.code());
+  TF_CHECK_OK(session->Close());
+}
+
 TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
   GraphDef graph;
   string node_names[3];
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index d83215d5c2a37e..7ea422187df575 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -490,5 +490,67 @@ message CallableOptions {
   // in the callable.
   repeated TensorConnection tensor_connection = 5;
 
-  // Next: 6
+  // The Tensor objects fed in the callable and fetched from the callable
+  // are expected to be backed by host (CPU) memory by default.
+  //
+  // The options below allow changing that - feeding tensors backed by
+  // device memory, or returning tensors that are backed by device memory.
+  //
+  // The maps below map the name of a feed/fetch tensor (which appears in
+  // 'feed' or 'fetch' fields above), to the fully qualified name of the device
+  // owning the memory backing the contents of the tensor.
+  //
+  // For example, creating a callable with the following options:
+  //
+  // CallableOptions {
+  //   feed: "a:0"
+  //   feed: "b:0"
+  //
+  //   fetch: "x:0"
+  //   fetch: "y:0"
+  //
+  //   feed_devices: {
+  //     "a:0": "/job:localhost/replica:0/task:0/device:GPU:0"
+  //   }
+  //
+  //   fetch_devices: {
+  //     "y:0": "/job:localhost/replica:0/task:0/device:GPU:0"
+  //   }
+  // }
+  //
+  // means that the Callable expects:
+  // - The first argument ("a:0") is a Tensor backed by GPU memory.
+  // - The second argument ("b:0") is a Tensor backed by host memory.
+  // and of its return values:
+  // - The first output ("x:0") will be backed by host memory.
+  // - The second output ("y:0") will be backed by GPU memory.
+  //
+  // FEEDS:
+  // It is the responsibility of the caller to ensure that the memory of the fed
+  // tensors will be correctly initialized and synchronized before it is
+  // accessed by operations executed during the call to Session::RunCallable().
+  //
+  // This is typically ensured by using the TensorFlow memory allocators
+  // (Device::GetAllocator()) to create the Tensor to be fed.
+  //
+  // Alternatively, for CUDA-enabled GPU devices, this typically means that the
+  // operation that produced the contents of the tensor has completed, i.e., the
+  // CUDA stream has been synchronized (e.g., via cuCtxSynchronize() or
+  // cuStreamSynchronize()).
+  map<string, string> feed_devices = 6;
+  map<string, string> fetch_devices = 7;
+
+  // By default, RunCallable() will synchronize the GPU stream before returning
+  // fetched tensors on a GPU device, to ensure that the values in those tensors
+  // have been produced. This simplifies interacting with the tensors, but
+  // potentially incurs a performance hit.
+  //
+  // If this option is set to true, the caller is responsible for ensuring
+  // that the values in the fetched tensors have been produced before they are
+  // used. The caller can do this by invoking `Device::Sync()` on the underlying
+  // device(s), or by feeding the tensors back to the same Session using
+  // `feed_devices` with the same corresponding device name.
+  bool fetch_skip_sync = 8;
+
+  // Next: 9
 }
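
Usage sketch (appended after the patch, not part of the commit): the snippet
below illustrates how a client might use the CallableOptions fields added
above. It assumes a Session whose graph exposes a feed "x:0" and a fetch
"y:0"; the helper name RunAndKeepOnGPU and the gpu_device argument are
illustrative only, and the device string would normally come from
Session::ListDevices().

    #include <string>
    #include <vector>

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/protobuf/config.pb.h"
    #include "tensorflow/core/public/session.h"

    namespace tf = tensorflow;

    // Builds a callable that feeds "x:0" from host memory and leaves the
    // fetched "y:0" backed by GPU memory. fetch_skip_sync must be true here:
    // the PruneGraph() change above rejects fetch_devices when
    // fetch_skip_sync is false.
    tf::Status RunAndKeepOnGPU(tf::Session* session, const tf::Tensor& x,
                               const std::string& gpu_device,
                               tf::Tensor* y_on_gpu) {
      tf::CallableOptions opts;
      opts.add_feed("x:0");
      opts.add_fetch("y:0");
      (*opts.mutable_fetch_devices())["y:0"] = gpu_device;
      opts.set_fetch_skip_sync(true);

      tf::Session::CallableHandle handle;
      tf::Status s = session->MakeCallable(opts, &handle);
      if (!s.ok()) return s;

      std::vector<tf::Tensor> outputs;
      s = session->RunCallable(handle, {x}, &outputs, nullptr);
      // Release the callable whether or not RunCallable() succeeded.
      const tf::Status released = session->ReleaseCallable(handle);
      if (!s.ok()) return s;
      if (!released.ok()) return released;

      // outputs[0] is backed by device memory; synchronize the device (or
      // feed it back through feed_devices) before reading its bytes on the
      // host.
      *y_on_gpu = outputs[0];
      return tf::Status::OK();
    }

The same options object with feed_devices set instead would let the caller
pass a GPU-backed tensor (for example, the one fetched above) straight back
into another callable without a copy through host memory.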