Skip to content

Commit

Permalink
Sets up a rendezvous for eager op execution.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 165019327
  • Loading branch information
alextp authored and tensorflower-gardener committed Aug 11, 2017
1 parent 3715cf6 commit 37c54be
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 3 deletions.
1 change: 1 addition & 0 deletions tensorflow/c/eager/BUILD
Expand Up @@ -15,6 +15,7 @@ cc_library(
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
Expand Down
8 changes: 7 additions & 1 deletion tensorflow/c/eager/c_api.cc
Expand Up @@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
Expand All @@ -53,6 +55,7 @@ struct TFE_Context {

// TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
TF_Session* session;
tensorflow::Rendezvous* rendezvous;

tensorflow::mutex functions_mu;
tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
Expand Down Expand Up @@ -135,6 +138,8 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
ret->session->device_mgr, opts->options.env, ret->devices()[i],
TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {});
}
ret->rendezvous =
new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);

return ret;
}
Expand All @@ -145,6 +150,7 @@ void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
TF_Graph* graph = ctx->session->graph;
TF_DeleteSession(ctx->session, status);
TF_DeleteGraph(graph);
ctx->rendezvous->Unref();
delete ctx;
}

Expand Down Expand Up @@ -470,7 +476,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
kernel = new tensorflow::KernelAndDevice();
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
if (!op->is_function()) {
status->status =
tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/c/eager/runtime.cc
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/runtime.h"

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
Expand Down Expand Up @@ -270,6 +271,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
params.output_attr_array = gtl::vector_as_array(&out_attrs);
params.function_library = flib_;
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
// TODO(apassos): use a thread pool.
std::function<void(std::function<void()>)> runner =
[](std::function<void()> f) { f(); };
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/c/eager/runtime.h
Expand Up @@ -174,7 +174,8 @@ class KernelAndDevice {
static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
KernelAndDevice* out);

KernelAndDevice() : device_(nullptr), flib_(nullptr) {}
KernelAndDevice(tensorflow::Rendezvous* rendez)
: device_(nullptr), flib_(nullptr), rendez_(rendez) {}

// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs);
Expand All @@ -186,6 +187,7 @@ class KernelAndDevice {
tensorflow::Device* device_;
tensorflow::FunctionLibraryRuntime* flib_;
tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
tensorflow::Rendezvous* rendez_;
};

} // namespace tensorflow
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/c/eager/runtime_test.cc
Expand Up @@ -70,7 +70,7 @@ TEST(KernelAndDevice, Run) {
.NumInputs(inputs.size())
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
KernelAndDevice kernel;
KernelAndDevice kernel(nullptr);
Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
ASSERT_TRUE(s.ok()) << s;
std::vector<Tensor> outputs;
Expand Down

0 comments on commit 37c54be

Please sign in to comment.