diff --git a/ngraph_bridge/enable_variable_ops/ngraph_assign_op.cc b/ngraph_bridge/enable_variable_ops/ngraph_assign_op.cc index 1213359e6..bcba19cce 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_assign_op.cc +++ b/ngraph_bridge/enable_variable_ops/ngraph_assign_op.cc @@ -58,7 +58,13 @@ class NGraphAssignOp : public OpKernel { // use_exclusive_lock_, validate_shape_, relax_constraints_; public: - ~NGraphAssignOp() { NGRAPH_VLOG(4) << "~NGraphAssignOp::" << name() << endl; } + ~NGraphAssignOp() { + NGRAPH_VLOG(4) << "~NGraphAssignOp::" << name() << endl; + // Delete from Input Variable Shared Name Map + string key = NGraphCatalog::CreateNodeKey(ng_graph_id_, name(), 0); + NGraphCatalog::DeleteFromInputVariableSharedNameMap(key); + } + explicit NGraphAssignOp(OpKernelConstruction* context) : OpKernel(context), is_tf_just_looking_(false), copy_to_tf_(false) { OP_REQUIRES_OK( @@ -124,19 +130,19 @@ class NGraphAssignOp : public OpKernel { // DO NOT CARE ABOUT SYNCING AS WE ARE ALWAYS SETTING THE NGTENSOR // Get input[1] - string valkey = to_string(ng_graph_id_) + "_" + def().input(1); - bool valref_exists = NGraphCatalog::ExistsInEncapOutputTensorMap(valkey); - if (valref_exists) { - // Value is from encap - NGRAPH_VLOG(4) << "NGraphAssign::Getting from catalog: " << valkey; - auto ng_val = NGraphCatalog::GetTensorFromEncapOutputTensorMap(valkey); - var->update_ng_tensor(ng_val); - } else { - NGRAPH_VLOG(4) << "NGraphAssign::Getting from TF : " << valkey; - if (var->update_ng_tensor(rhs_tensor)) { - number_of_copies++; - copy_log_str << " COPY_INP_VAL[0]"; - } + + // input[1] cannot be from NGraphEncap Op + // No way to get input node and check its type + string input_1_name = def().input(1); + OP_REQUIRES( + context, input_1_name.find("ngraph_cluster") == -1, + errors::Internal( + "Caught exception: Input to NGAssign from Encapsulate Op.\n")); + + NGRAPH_VLOG(4) << "NGraphAssign:: Updating"; + if (var->update_ng_tensor(rhs_tensor)) { + number_of_copies++; + copy_log_str << " COPY_INP_VAL[0]"; } mutex_lock l(*context->input_ref_mutex(0)); diff --git a/ngraph_bridge/enable_variable_ops/ngraph_catalog.cc b/ngraph_bridge/enable_variable_ops/ngraph_catalog.cc index 5f34ceee5..158d9cadb 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_catalog.cc +++ b/ngraph_bridge/enable_variable_ops/ngraph_catalog.cc @@ -29,25 +29,36 @@ namespace tensorflow { namespace ngraph_bridge { unordered_map NGraphCatalog::input_variable_sharedname_map_; -unordered_map> - NGraphCatalog::encap_output_tensor_map_; unordered_map> NGraphCatalog::encap_output_copy_indexes_map_; unordered_map> NGraphCatalog::encap_output_info_map_; +// Function to create the Node Key +string NGraphCatalog::CreateNodeKey(int graph_id, string node_name, int index) { + if (index == 0) { + return to_string(graph_id) + "_" + node_name; + } + return to_string(graph_id) + "_" + node_name + ":" + to_string(index); +} + // Functions for Encapsulate Output Copy Indexes Map -void NGraphCatalog::AddToEncapOutputCopyIndexesMap(string key, +void NGraphCatalog::AddToEncapOutputCopyIndexesMap(int graphid, + string node_name, unordered_set val) { + string key = graphid + "_" + node_name; NGraphCatalog::encap_output_copy_indexes_map_[key] = val; } unordered_set NGraphCatalog::GetEncapOutputIndexesThatNeedCopy( - string key) { + int graphid, string node_name) { + string key = graphid + "_" + node_name; return NGraphCatalog::encap_output_copy_indexes_map_[key]; } -bool NGraphCatalog::EncapOutputIndexNeedsCopy(string key, int index) { +bool NGraphCatalog::EncapOutputIndexNeedsCopy(int graphid, string node_name, + int index) { + string key = graphid + "_" + node_name; auto itr = NGraphCatalog::encap_output_copy_indexes_map_.find(key); if (itr != NGraphCatalog::encap_output_copy_indexes_map_.end()) { auto op_copy_indexes = itr->second; @@ -57,37 +68,10 @@ bool NGraphCatalog::EncapOutputIndexNeedsCopy(string key, int index) { return true; } -string NGraphCatalog::CreateNodeKey(int graph_id, string node_name, int index) { - if (index == 0) { - return to_string(graph_id) + "_" + node_name; - } - return to_string(graph_id) + "_" + node_name + ":" + to_string(index); -} - -// Functions for OutputTensorMap -void NGraphCatalog::AddToEncapOutputTensorMap( - string key, shared_ptr ng_val) { - NGraphCatalog::encap_output_tensor_map_[key] = ng_val; -} - -bool NGraphCatalog::ExistsInEncapOutputTensorMap(string key) { - auto itr = NGraphCatalog::encap_output_tensor_map_.find(key); - return itr != NGraphCatalog::encap_output_tensor_map_.end(); -} - -bool NGraphCatalog::ExistsInEncapOutputTensorMap(int graphid, string node_name, - int input_index) { - return NGraphCatalog::ExistsInEncapOutputTensorMap( - NGraphCatalog::CreateNodeKey(graphid, node_name, input_index)); -} - -shared_ptr -NGraphCatalog::GetTensorFromEncapOutputTensorMap(string key) { - return NGraphCatalog::encap_output_tensor_map_[key]; -} - -void NGraphCatalog::DeleteFromEncapOutputTensorMap(string key) { - NGraphCatalog::encap_output_tensor_map_.erase(key); +void NGraphCatalog::DeleteFromEncapOutputCopyIndexesMap(int graphid, + string node_name) { + string key = graphid + "_" + node_name; + NGraphCatalog::encap_output_copy_indexes_map_.erase(key); } // Functions relating Input Variable Shared Name Map @@ -114,6 +98,10 @@ bool NGraphCatalog::ExistsInInputVariableSharedNameMap(int graphid, NGraphCatalog::CreateNodeKey(graphid, node_name, input_index)); } +void NGraphCatalog::DeleteFromInputVariableSharedNameMap(string key) { + NGraphCatalog::input_variable_sharedname_map_.erase(key); +} + // Functions for EncapOutputInfo Map void NGraphCatalog::AddToEncapOutputInfoMap(string key, tuple val) { diff --git a/ngraph_bridge/enable_variable_ops/ngraph_catalog.h b/ngraph_bridge/enable_variable_ops/ngraph_catalog.h index 54699a21d..ce11bc530 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_catalog.h +++ b/ngraph_bridge/enable_variable_ops/ngraph_catalog.h @@ -50,25 +50,12 @@ class NGraphCatalog { // LOCK? static unordered_map input_variable_sharedname_map_; - // Map keeps track of nodes whose input is a tensor computed by NGraph - // For e.g. if the value to be assigned was computed by NGraphEncapsulate Op - // Will be used by Assign/Optimizers - // Map of - // Key - // when op index ==0 - // string : GraphId + _ + nodename - // otherwise - // string : GraphId + _ + nodename + : + output_index - // Value : shared_ptr - static unordered_map> - encap_output_tensor_map_; - // Map keeps track of output indexes of NGraphEncapsulate Op // that will be used by TF Nodes or other NGraphEncapsulate Op // Will be used by NGraphEncapsulateOP // Map of // Key - // string : nodename (nGraphEncapsulateOp name) + // string : GraphId + _ + nodename // Value : Set of indices static unordered_map> encap_output_copy_indexes_map_; @@ -91,12 +78,19 @@ class NGraphCatalog { encap_output_info_map_; public: + // Utility to create key to query the maps + static string CreateNodeKey(int graph_id, string node_name, int index); + // Utility Functions for the data structures // Functions for EncapsulateOutputCopyIndexes Map - static void AddToEncapOutputCopyIndexesMap(string key, + static void AddToEncapOutputCopyIndexesMap(int graphid, string node_name, unordered_set val); - static bool EncapOutputIndexNeedsCopy(string key, int index); - static unordered_set GetEncapOutputIndexesThatNeedCopy(string key); + static bool EncapOutputIndexNeedsCopy(int graphid, string node_name, + int index); + static unordered_set GetEncapOutputIndexesThatNeedCopy(int graphid, + string node_name); + static void DeleteFromEncapOutputCopyIndexesMap(int graphid, + string node_name); // Functions for InputVariableSharedName Map static string GetInputVariableSharedName(int graphid, string node_name, @@ -107,17 +101,7 @@ class NGraphCatalog { static bool ExistsInInputVariableSharedNameMap(string key); static bool ExistsInInputVariableSharedNameMap(int graphid, string node_name, int input_index); - - // Functions for EncapOutputTensorMap - static void AddToEncapOutputTensorMap(string key, - shared_ptr ng_val); - static bool ExistsInEncapOutputTensorMap(string key); - static bool ExistsInEncapOutputTensorMap(int graphid, string node_name, - int input_index); - - static shared_ptr GetTensorFromEncapOutputTensorMap( - string key); - static void DeleteFromEncapOutputTensorMap(string key); + static void DeleteFromInputVariableSharedNameMap(string key); // Functions for EncapOutputInfo Map static void AddToEncapOutputInfoMap(string key, @@ -134,9 +118,6 @@ class NGraphCatalog { static void DeleteFromEncapOutputInfoMap(string key); static void ClearEncapOutputInfoMap(); static void PrintEncapOutputInfoMap(); - - // Utility to create key to query the maps - static string CreateNodeKey(int graph_id, string node_name, int index); }; } // ngraph_bridge diff --git a/ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.cc b/ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.cc index 0a05ef043..1145f4cac 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.cc +++ b/ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.cc @@ -98,10 +98,12 @@ Status EnterInCatalog(Graph* graph, int graph_id) { NGRAPH_VLOG(4) << "Value: " << get<0>(value) << " " << get<1>(value) << " " << get<2>(value); NGraphCatalog::AddToEncapOutputInfoMap(key, value); - // TODO: Uncomment the continue when all the tasks are integrated - // continue; + // This NGraphAssign will be removed subsequently + // so we dont need to fill the rest of the catalog + continue; } } + // Update the input variable map if (IsNGVariableType(node->type_string())) { string node_key = NGraphCatalog::CreateNodeKey(graph_id, node->name(), 0); @@ -141,33 +143,10 @@ Status EnterInCatalog(Graph* graph, int graph_id) { op_index_to_copy.insert(edge->src_output()); } } - NGraphCatalog::AddToEncapOutputCopyIndexesMap(node->name(), + NGraphCatalog::AddToEncapOutputCopyIndexesMap(graph_id, node->name(), op_index_to_copy); } // end of node is type NGraphEncapsulate - - // Update the output tensor map - if (IsNGVariableType(node->type_string())) { - for (auto edge : node->in_edges()) { - if (!edge->src()->IsOp() || edge->IsControlEdge() || - IsRefType(edge->dst()->input_type(edge->dst_input())) || - edge->src()->type_string() != "NGraphEncapsulate") { - continue; - } - - NGRAPH_VLOG(4) << "Get " << node->type_string() - << " and input is from NGraphEncapsulate"; - - auto src = edge->src(); - int src_output = edge->src_output(); - string node_key = - NGraphCatalog::CreateNodeKey(graph_id, src->name(), src_output); - // Will be updated with real tensors in Encapsulate - NGraphCatalog::AddToEncapOutputTensorMap(node_key, nullptr); - NGRAPH_VLOG(4) << "Adding in Output Tensor Map"; - NGRAPH_VLOG(4) << "Key: " << node_key; - } - } // end of if node of type NGraphAssign } // enter in catalog NGRAPH_VLOG(4) << "Entered in Catalog"; diff --git a/ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.h b/ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.h index ba82559f4..1857446f1 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.h +++ b/ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.h @@ -50,11 +50,12 @@ namespace ngraph_bridge { // We add mapping of {graphId_nodename_InputIndex : Shared_Name} to the // InputVariableSharedNameMap // -// 2. If the output of NGraphEncapsulate Op is an input to NGraphVariableType -// Op, we store this NG-Tensor -// so that it can be directly accessed in compute call of NGraphVariableType. -// We add mapping of {graphId_encapnodename_OutputIndex : NG-Tensor} to the -// EncapOutputTensorMap +// 2. If the input to NGraphAssign Op is from NGraphEncapsulate Op +// We add mapping of +// {graphId_encapnodename_OutputIndex : tuple:{Variable_Shared_Name, CopyToTF, +// IsTFJustLooking}} +// to the EncapOutputInfoMap +// We attach "_ngraph_remove" attribute to this NGraphAssign node // // 3. If the output of NGraphEncapsulate Op is not required by a TF Op or // NGraphEncapsulate Op, diff --git a/ngraph_bridge/enable_variable_ops/ngraph_remove_ngraphassigns.cc b/ngraph_bridge/enable_variable_ops/ngraph_remove_ngraphassigns.cc index 9752248be..f18804978 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_remove_ngraphassigns.cc +++ b/ngraph_bridge/enable_variable_ops/ngraph_remove_ngraphassigns.cc @@ -47,6 +47,7 @@ Status RemoveNGraphAssigns(Graph* graph) { // Handle input edges NGRAPH_VLOG(3) << "Handling input edges "; + vector remove_edges; for (auto edge : node->in_edges()) { // attach incoming control edge to input_1, as that's where update // will happen @@ -55,8 +56,8 @@ Status RemoveNGraphAssigns(Graph* graph) { if (edge->src() == input_1) continue; graph->AddEdge(edge->src(), edge->src_output(), input_1, edge->dst_input()); - graph->RemoveEdge(edge); } + remove_edges.push_back(edge); } // Handle output edges @@ -80,6 +81,10 @@ Status RemoveNGraphAssigns(Graph* graph) { graph->AddEdge(input_1, Graph::kControlSlot, edge->dst(), Graph::kControlSlot); } + remove_edges.push_back(edge); + } + + for (auto edge : remove_edges) { graph->RemoveEdge(edge); } diff --git a/ngraph_bridge/enable_variable_ops/ngraph_rewrite_pass.cc b/ngraph_bridge/enable_variable_ops/ngraph_rewrite_pass.cc index 98e1067f5..603dbc50f 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_rewrite_pass.cc +++ b/ngraph_bridge/enable_variable_ops/ngraph_rewrite_pass.cc @@ -22,6 +22,7 @@ #include "logging/ngraph_log.h" #include "logging/tf_graph_writer.h" #include "ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.h" +#include "ngraph_bridge/enable_variable_ops/ngraph_remove_ngraphassigns.h" #include "ngraph_bridge/enable_variable_ops/ngraph_replace_variable_modifiers.h" #include "ngraph_bridge/ngraph_api.h" #include "ngraph_bridge/ngraph_assign_clusters.h" @@ -206,16 +207,22 @@ class NGraphVariableCapturePass : public NGraphRewritePass { // 2. Cluster Assignment [ngraph_assign_clusters.cc] // 3. Cluster Deassignment [ngraph_deassign_clusters.cc] // 4. Cluster Encapsulation [ngraph_encapsulate_clusters.cc] -// +// 5. Rewrite Variable Type Ops for Tracking [ngraph_rewrite_for_tracking.cc] +// 6. Enter In Catalog [ngraph_enter_in_catalog.cc] +// 7. Remove NGraphAssigns [ngraph_remove_ngraphassigns.cc] // Between phases, graph dumps (in both .dot and .pbtxt format) may be // requested by setting the following environment variables: // -// NGRAPH_TF_DUMP_UNMARKED_GRAPHS=1 dumps graphs before phase 1 -// NGRAPH_TF_DUMP_MARKED_GRAPHS=1 dumps graphs after phase 1 -// NGRAPH_TF_DUMP_CLUSTERED_GRAPHS=1 dumps graphs after phase 2 -// NGRAPH_TF_DUMP_DECLUSTERED_GRAPHS=1 dumps graphs after phase 3 -// NGRAPH_TF_DUMP_ENCAPSULATED_GRAPHS=1 dumps graphs after phase 4 -// NGRAPH_TF_DUMP_GRAPHS=1 all of the above +// NGRAPH_TF_DUMP_UNMARKED_GRAPHS=1 dumps graphs before phase 0 +// NGRAPH_TF_DUMP_REPLACEDMODIFIERS_GRAPHS=1 dumps graphs after phase 0 +// NGRAPH_TF_DUMP_MARKED_GRAPHS=1 dumps graphs after phase 1 +// NGRAPH_TF_DUMP_CLUSTERED_GRAPHS=1 dumps graphs after phase 2 +// NGRAPH_TF_DUMP_DECLUSTERED_GRAPHS=1 dumps graphs after phase 3 +// NGRAPH_TF_DUMP_ENCAPSULATED_GRAPHS=1 dumps graphs after phase 4 +// NGRAPH_TF_DUMP_TRACKED_GRAPHS=1 dumps graphs after phase 5 +// NGRAPH_TF_DUMP_CATALOGED_GRAPHS=1 dumps graphs after phase 6 +// NGRAPH_TF_DUMP_REMOVENGASSIGNS_GRAPHS=1 dumps graphs after phase 7 +// NGRAPH_TF_DUMP_GRAPHS=1 all of the above // class NGraphEncapsulationPass : public NGraphRewritePass { public: @@ -323,6 +330,13 @@ class NGraphEncapsulationPass : public NGraphRewritePass { "Graph with Variables Inputs Entered in Catalog"); } + // Remove Certain NGraphAssigns then. + TF_RETURN_IF_ERROR(RemoveNGraphAssigns(options.graph->get())); + if (DumpRemoveNGraphAssignsGraphs()) { + DumpGraphs(options, idx, "ngraphssigns_optimized", + "Graph with NGraphAssigns Optimized/Removed"); + } + return Status::OK(); } @@ -360,6 +374,11 @@ class NGraphEncapsulationPass : public NGraphRewritePass { return DumpAllGraphs() || std::getenv("NGRAPH_TF_DUMP_CATALOGED_GRAPHS") != nullptr; } + + static bool DumpRemoveNGraphAssignsGraphs() { + return DumpAllGraphs() || + std::getenv("NGRAPH_TF_DUMP_REMOVENGASSIGNS_GRAPHS") != nullptr; + } }; } // namespace ngraph_bridge diff --git a/ngraph_bridge/enable_variable_ops/ngraph_tracked_variable.cc b/ngraph_bridge/enable_variable_ops/ngraph_tracked_variable.cc index aad4a2eff..77a8e3f95 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_tracked_variable.cc +++ b/ngraph_bridge/enable_variable_ops/ngraph_tracked_variable.cc @@ -23,6 +23,7 @@ #include "ngraph/event_tracing.hpp" #include "ngraph/runtime/backend.hpp" +#include "ngraph_bridge/enable_variable_ops/ngraph_catalog.h" #include "ngraph_bridge/enable_variable_ops/ngraph_var.h" #include "ngraph_bridge/ngraph_backend_manager.h" #include "ngraph_bridge/ngraph_freshness_tracker.h" @@ -111,6 +112,8 @@ NGraphVariableOp::NGraphVariableOp(OpKernelConstruction* context) NGraphVariableOp::~NGraphVariableOp() { NGRAPH_VLOG(4) << "~NGraphVariableOp:: " << name() << endl; + string node_key = NGraphCatalog::CreateNodeKey(ng_graph_id_, name(), 0); + NGraphCatalog::DeleteFromInputVariableSharedNameMap(node_key); tracker_->Unref(); } diff --git a/ngraph_bridge/ngraph_encapsulate_op.cc b/ngraph_bridge/ngraph_encapsulate_op.cc index 735625299..01ae17be5 100644 --- a/ngraph_bridge/ngraph_encapsulate_op.cc +++ b/ngraph_bridge/ngraph_encapsulate_op.cc @@ -221,13 +221,29 @@ class NGraphEncapsulateOp : public OpKernel { } #if defined(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS) + // Remove Entries from Catalog + // Remove entries related to outputs for (int i = 0; i < m_number_outputs; i++) { string key = NGraphCatalog::CreateNodeKey(m_graph_id, name(), i); - if (NGraphCatalog::ExistsInEncapOutputTensorMap(key)) { - NGraphCatalog::DeleteFromEncapOutputTensorMap(key); - NGRAPH_VLOG(2) << "Deleting from output tensor map " << key; + if (NGraphCatalog::ExistsInEncapOutputInfoMap(key)) { + NGraphCatalog::DeleteFromEncapOutputInfoMap(key); + NGRAPH_VLOG(2) << "Deleting from output info map " << key; } } + + NGRAPH_VLOG(2) << "Deleting from Output Copy Index map " << name(); + NGraphCatalog::DeleteFromEncapOutputCopyIndexesMap(m_graph_id, name()); + + // Remove entries related to inputs + for (int i = 0; i < m_number_inputs; i++) { + string key = NGraphCatalog::CreateNodeKey(m_graph_id, name(), i); + if (NGraphCatalog::ExistsInInputVariableSharedNameMap(key)) { + NGraphCatalog::DeleteFromInputVariableSharedNameMap(key); + NGRAPH_VLOG(2) << "Deleting from input variable shared name map " + << key; + } + } + #endif // Release the backend @@ -674,13 +690,40 @@ class NGraphEncapsulateOp : public OpKernel { output_caches[i].second; void* current_dst_ptr = DMAHelper::base(output_tensor); - std::shared_ptr current_ng_tensor = - GetCurrentNgTensor(current_dst_ptr, last_dst_ptr, last_ng_tensor, - true, ng_exec, op_backend, ng_element_type, - ng_shape); + std::shared_ptr current_ng_tensor = nullptr; +// if the output tensor is going to be assigned to a variable +// we ask nGraph to provide the output directly in the variable tensor +#if defined(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS) + if (NGraphCatalog::ExistsInEncapOutputInfoMap(m_graph_id, name(), i)) { + string output_key = NGraphCatalog::CreateNodeKey(m_graph_id, name(), i); + string ref_var_name = + NGraphCatalog::GetVariableSharedNameFromEncapOutputInfoMap( + output_key); + NGraphVar* var; + OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup( + ctx->resource_manager()->default_container(), + ref_var_name, &var)); + current_ng_tensor = var->ng_tensor(); + + // There might be scenarios where the input and output tensors are the + // same.The staleness determined for the input tensor should be the + // final staleness for the given tensor. The staleness of output + // tensor should not matter as this tensor is meant to be + // overwritten with the computed value. + // So not setting staleness here. + output_caches[i] = std::make_pair(current_dst_ptr, current_ng_tensor); + var->Unref(); + ng_outputs.push_back(current_ng_tensor); + continue; + } +#endif + current_ng_tensor = GetCurrentNgTensor( + current_dst_ptr, last_dst_ptr, last_ng_tensor, true, ng_exec, + op_backend, ng_element_type, ng_shape); current_ng_tensor->set_stale(true); output_caches[i] = std::make_pair(current_dst_ptr, current_ng_tensor); + ng_outputs.push_back(current_ng_tensor); } @@ -795,22 +838,50 @@ class NGraphEncapsulateOp : public OpKernel { #if defined(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS) if (m_number_outputs == -1) { NGRAPH_VLOG(4) << "Settig number of outputs for " << def().name(); - m_number_outputs = output_caches.size(); + m_number_outputs = ng_outputs.size(); + NGRAPH_VLOG(4) << "Setting number of inputs for " << def().name(); + m_number_inputs = ng_inputs.size(); } for (size_t i = 0; i < output_tensor_count; ++i) { - string key = NGraphCatalog::CreateNodeKey(m_graph_id, def().name(), i); - bool ref_exists = NGraphCatalog::ExistsInEncapOutputTensorMap(key); - void* dst_ptr; - std::shared_ptr dst_ng_tensor; - std::tie(dst_ptr, dst_ng_tensor) = output_caches[i]; + // Sync the Var Tensor if required + string output_key = + NGraphCatalog::CreateNodeKey(m_graph_id, def().name(), i); + bool ref_exists = NGraphCatalog::ExistsInEncapOutputInfoMap(output_key); if (ref_exists) { - NGRAPH_VLOG(4) << "Adding in output tensor map " << key; - NGraphCatalog::AddToEncapOutputTensorMap(key, dst_ng_tensor); + NGRAPH_VLOG(4) << "Syncing the output var tensor " << output_key; + + // Get var + string ref_var_name = + NGraphCatalog::GetVariableSharedNameFromEncapOutputInfoMap( + output_key); + NGraphVar* var; + OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup( + ctx->resource_manager()->default_container(), + ref_var_name, &var)); + + if (NGraphCatalog::GetCopyToTFFromEncapOutputInfoMap(output_key)) { + if (var->copy_ng_to_tf()) { + number_of_copies++; + copy_log_str << " COPY_TF "; + } + if (!NGraphCatalog::GetIsTFJustLookingFromEncapOutputInfoMap( + output_key)) { + // Some tf op might update the ng-tensor value so mark it stale + copy_log_str << " SET_SYNC "; + var->set_sync_ng_tensor(true); + } + } + var->Unref(); } + std::shared_ptr dst_ng_tensor; + void* dst_ptr; + std::tie(dst_ptr, dst_ng_tensor) = output_caches[i]; + if (m_op_backend_name != "CPU" && - NGraphCatalog::EncapOutputIndexNeedsCopy(def().name(), i)) { + NGraphCatalog::EncapOutputIndexNeedsCopy(m_graph_id, def().name(), + i)) { number_of_copies++; copy_log_str << " COPY_OP_VAL[" << i << "]"; @@ -987,6 +1058,7 @@ class NGraphEncapsulateOp : public OpKernel { static int s_instance_count; int my_instance_id{0}; int m_number_outputs = -1; + int m_number_inputs = -1; }; int NGraphEncapsulateOp::s_instance_count = 0; @@ -996,4 +1068,4 @@ int NGraphEncapsulateOp::s_instance_count = 0; REGISTER_KERNEL_BUILDER(Name("NGraphEncapsulate").Device(DEVICE_CPU), ngraph_bridge::NGraphEncapsulateOp); -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow