Skip to content

Commit

Permalink
Merge pull request #152 from zkmkarlsruhe/master
Browse files Browse the repository at this point in the history
Frozen graph support
  • Loading branch information
serizba authored Feb 11, 2022
2 parents 4be8da9 + 30ceacb commit 1b6cbde
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 9 deletions.
13 changes: 13 additions & 0 deletions examples/load_frozen_graph/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
cmake_minimum_required(VERSION 3.10)
project(example)

find_library(TENSORFLOW_LIB tensorflow HINT $ENV{HOME}/libtensorflow2/lib)

set(CMAKE_CXX_STANDARD 17)
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -fsanitize=address")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer")
set(CMAKE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} -lasan")

add_executable(example main.cpp)
target_include_directories(example PRIVATE ../../include $ENV{HOME}/libtensorflow2/include)
target_link_libraries (example "${TENSORFLOW_LIB}")
22 changes: 22 additions & 0 deletions examples/load_frozen_graph/create_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2,
)

input = tf.keras.Input(shape=(5,))
output = tf.keras.layers.Dense(5, activation=tf.nn.relu)(input)
output = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(output)
model = tf.keras.Model(inputs=input, outputs=output)

# Create frozen graph
x = tf.TensorSpec(model.input_shape, tf.float32, name="x")
concrete_function = tf.function(lambda x: model(x)).get_concrete_function(x)
frozen_model = convert_variables_to_constants_v2(concrete_function)

# Check input/output node name
print(f"{frozen_model.inputs=}")
print(f"{frozen_model.outputs=}")

# Save the graph as protobuf format
directory = "."
tf.io.write_graph(frozen_model.graph, directory, "model.pb", as_text=False)
22 changes: 22 additions & 0 deletions examples/load_frozen_graph/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include <iostream>

#include "cppflow/ops.h"
#include "cppflow/model.h"


int main() {

auto input = cppflow::fill({10, 5}, 1.0f);
std::cout << "start" << std::endl;
cppflow::model model("../model.pb", cppflow::model::FROZEN_GRAPH);
auto output = model({{"x:0", input}}, {{"Identity:0"}})[0];

std::cout << output << std::endl;

auto values = output.get_data<float>();

for (auto v : values) {
std::cout << v << std::endl;
}
return 0;
}
Binary file added examples/load_frozen_graph/model.pb
Binary file not shown.
81 changes: 72 additions & 9 deletions include/cppflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ namespace cppflow {

class model {
public:
explicit model(const std::string& filename);
enum TYPE
{
SAVED_MODEL,
FROZEN_GRAPH,
};

explicit model(const std::string& filename, const TYPE type=TYPE::SAVED_MODEL);

std::vector<std::string> get_operations() const;
std::vector<int64_t> get_operation_shape(const std::string& operation) const;
Expand All @@ -34,6 +40,7 @@ namespace cppflow {
model &operator=(model &&other) = default;

private:
TF_Buffer * readGraph(const std::string& filename);

std::shared_ptr<TF_Graph> graph;
std::shared_ptr<TF_Session> session;
Expand All @@ -43,24 +50,44 @@ namespace cppflow {

namespace cppflow {

inline model::model(const std::string &filename) {
inline model::model(const std::string &filename, const TYPE type) {
this->graph = {TF_NewGraph(), TF_DeleteGraph};

// Create the session.
std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions(), TF_DeleteSessionOptions};
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString("", 0), TF_DeleteBuffer};
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer(), TF_DeleteBuffer};

auto session_deleter = [](TF_Session* sess) {
TF_DeleteSession(sess, context::get_status());
status_check(context::get_status());
};

int tag_len = 1;
const char* tag = "serve";
this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(),
&tag, tag_len, this->graph.get(), meta_graph.get(), context::get_status()),
session_deleter};
if (type == TYPE::SAVED_MODEL) {
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString("", 0), TF_DeleteBuffer};
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer(), TF_DeleteBuffer};

int tag_len = 1;
const char* tag = "serve";
this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(),
&tag, tag_len, this->graph.get(), meta_graph.get(), context::get_status()),
session_deleter};
}
else if (type == TYPE::FROZEN_GRAPH) {
this->session = {TF_NewSession(this->graph.get(), session_options.get(), context::get_status()), session_deleter};
status_check(context::get_status());

// Import the graph definition
TF_Buffer* def = readGraph(filename);
if(def == nullptr) {
throw std::runtime_error("Failed to import graph def from file");
}

std::unique_ptr<TF_ImportGraphDefOptions, decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), context::get_status());
TF_DeleteBuffer(def);
}
else {
throw std::runtime_error("Model type unknown");
}

status_check(context::get_status());
}
Expand Down Expand Up @@ -169,6 +196,42 @@ namespace cppflow {
inline tensor model::operator()(const tensor& input) {
return (*this)({{"serving_default_input_1", input}}, {"StatefulPartitionedCall"})[0];
}


inline TF_Buffer * model::readGraph(const std::string& filename) {
std::ifstream file (filename, std::ios::binary | std::ios::ate);

// Error opening the file
if (!file.is_open()) {
std::cerr << "Unable to open file: " << filename << std::endl;
return nullptr;
}

// Cursor is at the end to get size
auto size = file.tellg();
// Move cursor to the beginning
file.seekg (0, std::ios::beg);

// Read
auto data = std::make_unique<char[]>(size);
file.seekg (0, std::ios::beg);
file.read (data.get(), size);

// Error reading the file
if (!file) {
std::cerr << "Unable to read the full file: " << filename << std::endl;
return nullptr;
}

// Create tensorflow buffer from read data
TF_Buffer* buffer = TF_NewBufferFromString(data.get(), size);

// Close file and remove data
file.close();

return buffer;
}

}

#endif //CPPFLOW2_MODEL_H

0 comments on commit 1b6cbde

Please sign in to comment.