Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ngraph_bridge/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "ngraph/op/util/logical_reduction.hpp"

#include "logging/ngraph_log.h"
#include "ngraph_bridge/ngraph_api.h"
#include "ngraph_bridge/ngraph_backend_manager.h"
#include "ngraph_bridge/ngraph_builder.h"
#include "ngraph_bridge/ngraph_conversions.h"
Expand Down Expand Up @@ -97,6 +98,9 @@ std::shared_ptr<TOpType> ConstructNgNode(const std::string& op_name,
auto ng_node = std::make_shared<TOpType>(std::forward<TArg>(Args)...);
ng_node->set_friendly_name(op_name);
ng_node->add_provenance_tag(op_name);
if (config::IsLoggingPlacement()) {
cout << "TF_to_NG: " << op_name << " --> " << ng_node->get_name() << "\n";
}
return ng_node;
}

Expand Down
14 changes: 9 additions & 5 deletions ngraph_bridge/ngraph_encapsulate_clusters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,10 +763,6 @@ Status EncapsulateClusters(
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(opts, *gdef_for_current_encapsulate,
&graph_for_current_encapsulate));
TF_RETURN_IF_ERROR(Builder::TranslateGraph(
input_shapes, static_input_map, &graph_for_current_encapsulate,
ng_function));
string serialized_ngfunc(ngraph::serialize(ng_function, 4));

// get backend.
// TODO: Note that this is code duplication of some stuff present
Expand All @@ -793,6 +789,14 @@ Status EncapsulateClusters(
}
TF_RETURN_IF_ERROR(BackendManager::CreateBackend(
op_backend_name)); // Created a backend here. must free it
// TranslateGraph must be called AFTER CreateBackend because some TF
// ops like CNMS and gather use backend specific nodes
TF_RETURN_IF_ERROR(Builder::TranslateGraph(
input_shapes, static_input_map, &graph_for_current_encapsulate,
ng_function));
int json_indentation = 4;
string serialized_ngfunc(
ngraph::serialize(ng_function, json_indentation));
std::unordered_map<std::string, std::string> additional_attribute_map;
for (auto itr : node->attrs()) {
// Find the optional attributes to be sent to the backend.
Expand All @@ -804,7 +808,7 @@ Status EncapsulateClusters(
// For e.g. _ngraph_ice_cores --> ice_cores
if (itr.first.find("_ngraph_") != std::string::npos) {
// leave out _ngraph_aot_requested
if (itr.first.find("_ngraph_aot_requested") !=
if (itr.first.find("_ngraph_aot_requested") ==
std::string::npos) {
additional_attribute_map.insert(
{itr.first.substr(strlen("_ngraph_")), itr.second.s()});
Expand Down
33 changes: 19 additions & 14 deletions ngraph_bridge/ngraph_encapsulate_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,32 +134,32 @@ Status NGraphEncapsulateImpl::GetNgExecutable(
MemoryProfile(vm0, rss0);

NGRAPH_VLOG(1) << "Compilation cache miss: " << m_name;
string serialized_ng_func;
if (!m_do_aot) {
TF_RETURN_IF_ERROR(Builder::TranslateGraph(input_shapes, static_input_map,
&m_graph, ng_function));
ng_function->set_friendly_name(m_name);
int json_indentation = 4;
serialized_ng_func = ngraph::serialize(ng_function, json_indentation);
} else {
auto itr = m_aot_functions.find(signature);
if (itr == m_aot_functions.end()) {
return errors::Internal(
"Expected to find AOT precompiled ng function of signature: ",
signature);
}
ng_function = ng::deserialize(itr->second);
serialized_ng_func = itr->second;
}

auto function_size = ng_function->get_graph_size() / 1024; // kb unit

// Serialize to nGraph if needed
if (std::getenv("NGRAPH_ENABLE_SERIALIZE") != nullptr) {
std::string file_name = "tf_function_" + m_name + ".json";
NgraphSerialize("tf_function_" + m_name + ".json", ng_function);
StringToFile("tf_function_" + m_name + ".json", serialized_ng_func);
#if defined NGRAPH_DISTRIBUTED
int rank_id;
rank_id = ng::get_distributed_interface()->get_rank();
NgraphSerialize(
"tf_function_" + m_name + "_" + to_string(rank_id) + ".json",
ng_function);
StringToFile("tf_function_" + m_name + "_" + to_string(rank_id) + ".json",
serialized_ng_func);
#endif
}
// Evict the cache if the number of elements exceeds the limit
Expand All @@ -172,7 +172,7 @@ Status NGraphEncapsulateImpl::GetNgExecutable(
int input_tensors_bytes_free = 0;
evicted_ng_exec = m_ng_exec_map[m_lru.back()];
m_ng_exec_map.erase(m_lru.back());
m_ng_function_map.erase(evicted_ng_exec);
m_serialized_ng_function_map.erase(evicted_ng_exec);

// Call delete function here for the erased func
op_backend->remove_compiled_function(evicted_ng_exec);
Expand Down Expand Up @@ -222,12 +222,12 @@ Status NGraphEncapsulateImpl::GetNgExecutable(
}
} catch (const std::exception& exp) {
BackendManager::UnlockBackend(m_op_backend_name);
NgraphSerialize("tf_function_error_" + m_name + ".json", ng_function);
StringToFile("tf_function_error_" + m_name + ".json", serialized_ng_func);
return errors::Internal("Caught exception while compiling op_backend: ",
exp.what(), "\n");
} catch (...) {
BackendManager::UnlockBackend(m_op_backend_name);
NgraphSerialize("tf_function_error_" + m_name + ".json", ng_function);
StringToFile("tf_function_error_" + m_name + ".json", serialized_ng_func);
return errors::Internal("Error in compiling op_backend\n");
}
BackendManager::UnlockBackend(m_op_backend_name);
Expand All @@ -236,7 +236,7 @@ Status NGraphEncapsulateImpl::GetNgExecutable(

SetNgExecMap(signature, ng_exec);
// caching ng_function to serialize to ngraph if needed
SetNgFunctionMap(ng_exec, ng_function);
m_serialized_ng_function_map[ng_exec] = serialized_ng_func;

m_lru.push_front(signature);
// Memory after
Expand All @@ -245,9 +245,8 @@ Status NGraphEncapsulateImpl::GetNgExecutable(
auto delta_res_mem = rss - rss0;
NGRAPH_VLOG(1) << "NGRAPH_TF_CACHE_PROFILE: OP_ID: " << my_instance_id
<< " Cache length: " << m_ng_exec_map.size()
<< " Cluster: " << m_name << " Delta VM: " << delta_vm_mem
<< " Delta RSS: " << delta_res_mem
<< " Function size: " << function_size
<< " Cluster: " << m_name << " Delta VM: " << delta_vm_mem
<< " Delta RSS: " << delta_res_mem
<< " KB Total RSS: " << rss / (1024 * 1024) << " GB "
<< " VM: " << vm / (1024 * 1024) << " GB" << endl;
} // end of input signature not found in m_ng_exec_map
Expand Down Expand Up @@ -582,6 +581,12 @@ NGraphEncapsulateImpl::GetTensorsFromPipeline(
return out_tpl;
}

void NGraphEncapsulateImpl::DumpNgFunction(
const string& file_name,
std::shared_ptr<ngraph::runtime::Executable> ng_exec) {
StringToFile(file_name, m_serialized_ng_function_map[ng_exec]);
}

} // namespace ngraph_bridge

} // namespace tensorflow
25 changes: 9 additions & 16 deletions ngraph_bridge/ngraph_encapsulate_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ class NGraphEncapsulateImpl {
const ng::element::Type& ng_element_type, const ng::Shape& ng_shape,
std::shared_ptr<ng::runtime::Tensor> tensor_from_pipeline);

void DumpNgFunction(const string&,
std::shared_ptr<ngraph::runtime::Executable>);

// Accessors(getters and setters) for the private data members of
// NgraphEncapsulateImpl class
// needed by
Expand Down Expand Up @@ -148,19 +151,6 @@ class NGraphEncapsulateImpl {

void ClearNgExecMap() { m_ng_exec_map.clear(); }

std::unordered_map<std::shared_ptr<ngraph::runtime::Executable>,
std::shared_ptr<ngraph::Function>>
GetNgFunctionMap() {
return m_ng_function_map;
}

void SetNgFunctionMap(
const std::shared_ptr<ngraph::runtime::Executable>& exec,
const std::shared_ptr<ngraph::Function>& function) {
m_ng_function_map[exec] = function;
}

void ClearNgFunctionMap() { m_ng_function_map.clear(); }
// TODO:sindhu have another get function for output_cache which is only
// readable
std::vector<std::pair<void*, shared_ptr<ng::runtime::Tensor>>>&
Expand All @@ -179,6 +169,10 @@ class NGraphEncapsulateImpl {

void ClearNgExecOutputCache() { m_ng_exec_output_cache_map.clear(); }

void ClearNgExecSerializedFunctionCache() {
m_serialized_ng_function_map.clear();
}

NGraphFreshnessTracker* GetNgraphFreshnessTracker() {
return m_freshness_tracker;
}
Expand Down Expand Up @@ -236,9 +230,8 @@ class NGraphEncapsulateImpl {
// ng_function, ng_executable, Output and Input Cache maps
std::unordered_map<std::string, std::shared_ptr<ngraph::runtime::Executable>>
m_ng_exec_map;
std::unordered_map<std::shared_ptr<ngraph::runtime::Executable>,
std::shared_ptr<ngraph::Function>>
m_ng_function_map;
std::unordered_map<std::shared_ptr<ngraph::runtime::Executable>, std::string>
m_serialized_ng_function_map;

NgFunctionIOCache m_ng_exec_input_cache_map;
NgFunctionIOCache m_ng_exec_output_cache_map;
Expand Down
13 changes: 5 additions & 8 deletions ngraph_bridge/ngraph_encapsulate_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ NGraphEncapsulateOp::~NGraphEncapsulateOp() {
ng_encap_impl.ClearNgExecInputCache();
ng_encap_impl.ClearNgExecOutputCache();
ng_encap_impl.ClearNgExecMap();
ng_encap_impl.ClearNgFunctionMap();
ng_encap_impl.ClearNgExecPipelinedTensorMap();
ng_encap_impl.ClearNgExecSerializedFunctionCache();

// Release the backend
NGRAPH_VLOG(2) << "~NGraphEncapsulateOp():: ReleaseBackend";
Expand Down Expand Up @@ -291,7 +291,6 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {

std::vector<TensorShape> input_shapes;
std::vector<const Tensor*> static_input_map;
std::shared_ptr<ngraph::Function> ng_function;
std::shared_ptr<ngraph::runtime::Executable> ng_exec;
ng::runtime::Backend* op_backend;

Expand Down Expand Up @@ -526,19 +525,17 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {
try {
ng_exec->call(ng_outputs, ng_inputs);
} catch (const std::exception& exp) {
ng_function = ng_encap_impl.GetNgFunctionMap()[ng_exec];
BackendManager::UnlockBackend(ng_encap_impl.GetOpBackend());
NgraphSerialize("tf_function_error_" + ctx->op_kernel().name() + ".json",
ng_function);
ng_encap_impl.DumpNgFunction(
"tf_function_error_" + ctx->op_kernel().name() + ".json", ng_exec);
OP_REQUIRES(ctx, false,
errors::Internal(
"Caught exception while executing nGraph computation: ",
exp.what(), "\n"));
} catch (...) {
ng_function = ng_encap_impl.GetNgFunctionMap()[ng_exec];
BackendManager::UnlockBackend(ng_encap_impl.GetOpBackend());
NgraphSerialize("tf_function_error_" + ctx->op_kernel().name() + ".json",
ng_function);
ng_encap_impl.DumpNgFunction(
"tf_function_error_" + ctx->op_kernel().name() + ".json", ng_exec);
OP_REQUIRES(
ctx, false,
errors::Internal("Error in executing the nGraph computation\n"));
Expand Down
8 changes: 6 additions & 2 deletions ngraph_bridge/ngraph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,16 @@ Status CheckAxisDimInRange(std::vector<int64> axes, size_t rank) {
void NgraphSerialize(const std::string& file_name,
const std::shared_ptr<ngraph::Function>& ng_function) {
NGRAPH_VLOG(0) << "Serializing graph to: " << file_name << std::endl;
std::string js = ngraph::serialize(ng_function, 4);
int json_indentation = 4;
StringToFile(file_name, ngraph::serialize(ng_function, json_indentation));
}

void StringToFile(const std::string& file_name, const std::string& contents) {
std::ofstream f;
f.exceptions(std::ofstream::failbit | std::ofstream::badbit);
try {
f.open(file_name);
f << js;
f << contents;
f.close();
} catch (std::ofstream::failure& e) {
NGRAPH_VLOG(0) << "Exception opening/closing file " << file_name
Expand Down
2 changes: 2 additions & 0 deletions ngraph_bridge/ngraph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ Status CheckAxisDimInRange(std::vector<int64> axes, size_t rank);
void NgraphSerialize(const std::string&,
const std::shared_ptr<ngraph::Function>&);

void StringToFile(const std::string&, const std::string&);

// Collect the total memory usage through /proc/self/stat
void MemoryProfile(long&, long&);

Expand Down
2 changes: 1 addition & 1 deletion test/model_level_tests/models/MLP/getting_repo_ready.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pip install -U keras
pip install -U keras==2.2.5
1 change: 1 addition & 0 deletions test/model_level_tests/models/MLP/repo.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
https://github.com/keras-team/keras.git
2.2.5
62 changes: 62 additions & 0 deletions test/python/test_ngraph_serialize_flag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytest for a simple run on model testing framework

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pytest
import platform
import os

import tensorflow as tf
import numpy as np
import re

from common import NgraphTest
import ngraph_bridge


class TestNgraphSerialize(NgraphTest):

def test_ng_serialize_to_json(self):
initial_contents = set(os.listdir())
xshape = (3, 4, 5)
x = tf.placeholder(tf.float32, shape=xshape)
out = tf.nn.l2_loss(tf.abs(x))
values = np.random.rand(*xshape)

config = ngraph_bridge.update_config(tf.ConfigProto())
ngraph_enable_serialize = os.environ.pop('NGRAPH_ENABLE_SERIALIZE',
None)
os.environ['NGRAPH_ENABLE_SERIALIZE'] = '1'
ngraph_bridge.enable()
with tf.Session(config=config) as sess:
out = sess.run((out), feed_dict={x: values})
os.environ.pop('NGRAPH_ENABLE_SERIALIZE', None)
if ngraph_enable_serialize is not None:
os.environ['NGRAPH_ENABLE_SERIALIZE'] = \
ngraph_enable_serialize

final_contents = set(os.listdir())
assert (len(final_contents) - len(initial_contents) == 1)
new_files = final_contents.difference(initial_contents)
flname = new_files.pop()
assert (flname.startswith('tf_function_') and flname.endswith('json'))
os.remove(flname)