Skip to content

Commit

Permalink
[tfjs-node] Add SavedModel signatureDef loading and deleting API (#2217)
Browse files Browse the repository at this point in the history
FEATURE
* add proto_pb.js and test it

* sync gpu

* update

* fix nit

* update gpu

* remove gpu cloud test

* save

* save

* save

* update signatureDefInfo

* fix nit

* add comments

* fix lint

* manage session

* update test

* fix nit

* add doc

* update doc

* address comment

* fix nit

* print input/output node names

* array to string

* address comments

* missing .

* manual copy api_pb.js

* manual copy api_pb.js in gpu

* mkdir -p

* add test objects in gpu

* save

* use ModelTensorInfo from core
  • Loading branch information
Kangyi Zhang committed Oct 28, 2019
1 parent 1eec198 commit 10ba224
Show file tree
Hide file tree
Showing 19 changed files with 626 additions and 76 deletions.
2 changes: 1 addition & 1 deletion tfjs-node-gpu/package.json
Expand Up @@ -12,7 +12,7 @@
"node": ">=8.11.0"
},
"scripts": {
"build": "tsc",
"build": "tsc && mkdir -p dist/proto && cp src/proto/api_pb.js dist/proto/api_pb.js",
"build-npm": "yarn prep-gpu && ./scripts/build-npm.sh",
"build-addon": "./scripts/build-and-upload-addon.sh",
"build-addon-from-source": "node-pre-gyp install --build-from-source",
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
94 changes: 88 additions & 6 deletions tfjs-node/binding/tfjs_backend.cc
Expand Up @@ -17,16 +17,15 @@

#include "tfjs_backend.h"

#include "napi_auto_ref.h"
#include "tf_auto_tensor.h"
#include "tfe_auto_op.h"
#include "utils.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <set>
#include <string>
#include "napi_auto_ref.h"
#include "tf_auto_tensor.h"
#include "tfe_auto_op.h"
#include "utils.h"

namespace tfnodejs {

Expand Down Expand Up @@ -695,7 +694,8 @@ void AssignOpAttr(napi_env env, TFE_Op *tfe_op, napi_value attr_value) {
}
}

TFJSBackend::TFJSBackend(napi_env env) : next_tensor_id_(0) {
TFJSBackend::TFJSBackend(napi_env env)
: next_tensor_id_(0), next_savedmodel_id_(0) {
TF_AutoStatus tf_status;
TFE_ContextOptions *tfe_options = TFE_NewContextOptions();
tfe_context_ = TFE_NewContext(tfe_options, tf_status.status);
Expand Down Expand Up @@ -746,6 +746,10 @@ TFJSBackend::~TFJSBackend() {
for (auto &kv : tfe_handle_map_) {
TFE_DeleteTensorHandle(kv.second);
}
for (auto &kv : tf_savedmodel_map_) {
TF_AutoStatus tf_status;
TF_DeleteSession(kv.second, tf_status.status);
}
if (tfe_context_ != nullptr) {
TFE_DeleteContext(tfe_context_);
}
Expand All @@ -758,6 +762,12 @@ int32_t TFJSBackend::InsertHandle(TFE_TensorHandle *tfe_handle) {
.first->first;
}

int32_t TFJSBackend::InsertSavedModel(TF_Session *tf_session) {
return tf_savedmodel_map_
.insert(std::make_pair(next_savedmodel_id_++, tf_session))
.first->first;
}

napi_value TFJSBackend::CreateTensor(napi_env env, napi_value shape_value,
napi_value dtype_value,
napi_value array_value) {
Expand Down Expand Up @@ -950,4 +960,76 @@ napi_value TFJSBackend::ExecuteOp(napi_env env, napi_value op_name_value,
return output_tensor_infos;
}

napi_value TFJSBackend::LoadSavedModel(napi_env env,
napi_value export_dir_value,
napi_value tags_value) {
TF_SessionOptions *session_options = TF_NewSessionOptions();

TF_Buffer *run_options = TF_NewBufferFromString("", 0);

std::string export_dir_string;
napi_status nstatus;
nstatus = GetStringParam(env, export_dir_value, export_dir_string);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
const char *export_dir = export_dir_string.c_str();

std::string tags;
nstatus = GetStringParam(env, tags_value, tags);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

std::vector<const char *> tags_ptrs = splitStringByComma(tags);

TF_Graph *graph = TF_NewGraph();

TF_Buffer *metagraph = TF_NewBuffer();

TF_AutoStatus tf_status;

TF_Session *session = TF_LoadSessionFromSavedModel(
session_options, run_options, export_dir, tags_ptrs.data(),
tags_ptrs.size(), graph, metagraph, tf_status.status);
// Delete objects that are necessary when loading the SavedModel but not gonna
// be used later.
TF_DeleteSessionOptions(session_options);
TF_DeleteBuffer(run_options);
TF_DeleteBuffer(metagraph);
TF_DeleteGraph(graph);

if (TF_GetCode(tf_status.status) != TF_OK) {
NAPI_THROW_ERROR(env, "Failed to load SavedModel: %s",
TF_Message(tf_status.status));
return nullptr;
}

napi_value output_session_id;
nstatus =
napi_create_int32(env, InsertSavedModel(session), &output_session_id);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
return output_session_id;
}

void TFJSBackend::DeleteSavedModel(napi_env env,
napi_value savedmodel_id_value) {
int32_t savedmodel_id;
ENSURE_NAPI_OK(
env, napi_get_value_int32(env, savedmodel_id_value, &savedmodel_id));

auto savedmodel_entry = tf_savedmodel_map_.find(savedmodel_id);
if (savedmodel_entry == tf_savedmodel_map_.end()) {
NAPI_THROW_ERROR(
env, "Delete called on a SavedModel not referenced (savedmodel_id: %d)",
savedmodel_id);
return;
}

TF_AutoStatus tf_status;
TF_DeleteSession(savedmodel_entry->second, tf_status.status);
if (TF_GetCode(tf_status.status) != TF_OK) {
NAPI_THROW_ERROR(env, "Failed to delete SavedModel: %s",
TF_Message(tf_status.status));
return;
}
tf_savedmodel_map_.erase(savedmodel_entry);
}

} // namespace tfnodejs
22 changes: 18 additions & 4 deletions tfjs-node/binding/tfjs_backend.h
Expand Up @@ -22,6 +22,7 @@
#include <map>
#include <memory>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

namespace tfnodejs {
Expand All @@ -30,7 +31,7 @@ class TFJSBackend {
public:
// Creates, initializes, and returns a TFJSBackend instance. If initialization
// fails, a nullptr is returned.
static TFJSBackend* Create(napi_env env);
static TFJSBackend *Create(napi_env env);

// Creates a new Tensor with given shape and data and returns an ID that
// refernces the new Tensor.
Expand Down Expand Up @@ -59,15 +60,28 @@ class TFJSBackend {
napi_value op_attr_inputs, napi_value input_tensor_ids,
napi_value num_output_values);

// Load a SavedModel from a path:
// - export_dir (string)
// - tags_value (string)
napi_value LoadSavedModel(napi_env env, napi_value export_dir,
napi_value tags_value);

// Delete the SavedModel corresponding TF_Session and TF_Graph
// - saved_model_id (number)
void DeleteSavedModel(napi_env env, napi_value saved_model_id);

private:
TFJSBackend(napi_env env);
~TFJSBackend();

int32_t InsertHandle(TFE_TensorHandle* tfe_handle);
int32_t InsertHandle(TFE_TensorHandle *tfe_handle);
int32_t InsertSavedModel(TF_Session *tf_session);

TFE_Context* tfe_context_;
std::map<int32_t, TFE_TensorHandle*> tfe_handle_map_;
TFE_Context *tfe_context_;
std::map<int32_t, TFE_TensorHandle *> tfe_handle_map_;
std::map<int32_t, TF_Session *> tf_savedmodel_map_;
int32_t next_tensor_id_;
int32_t next_savedmodel_id_;
std::string device_name;

public:
Expand Down
78 changes: 71 additions & 7 deletions tfjs-node/binding/tfjs_binding.cc
Expand Up @@ -21,10 +21,10 @@

namespace tfnodejs {

TFJSBackend* gBackend = nullptr;
TFJSBackend *gBackend = nullptr;

static void AssignIntProperty(napi_env env, napi_value exports,
const char* name, int32_t value) {
const char *name, int32_t value) {
napi_value js_value;
napi_status nstatus = napi_create_int32(env, value, &js_value);
ENSURE_NAPI_OK(env, nstatus);
Expand All @@ -47,7 +47,10 @@ static napi_value CreateTensor(napi_env env, napi_callback_info info) {
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

if (argc < 3) {
NAPI_THROW_ERROR(env, "Invalid number of args passed to createTensor()");
NAPI_THROW_ERROR(env,
"Invalid number of args passed to createTensor(). "
"Expecting 3 args but got %d.",
argc);
return nullptr;
}

Expand Down Expand Up @@ -76,7 +79,10 @@ static napi_value DeleteTensor(napi_env env, napi_callback_info info) {
ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);

if (argc < 1) {
NAPI_THROW_ERROR(env, "Invalid number of args passed to deleteTensor()");
NAPI_THROW_ERROR(env,
"Invalid number of args passed to deleteTensor(). "
"Expecting 1 arg but got %d.",
argc);
return js_this;
}

Expand All @@ -97,7 +103,10 @@ static napi_value TensorDataSync(napi_env env, napi_callback_info info) {
ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);

if (argc < 1) {
NAPI_THROW_ERROR(env, "Invalid number of args passed to tensorDataSync()");
NAPI_THROW_ERROR(env,
"Invalid number of args passed to tensorDataSync(). "
"Expecting 1 arg but got %d.",
argc);
return nullptr;
}

Expand All @@ -109,7 +118,7 @@ static napi_value TensorDataSync(napi_env env, napi_callback_info info) {
static napi_value ExecuteOp(napi_env env, napi_callback_info info) {
napi_status nstatus;

// Create tensor takes 3 params: op-name, op-attrs, input-tensor-ids,
// Create tensor takes 4 params: op-name, op-attrs, input-tensor-ids,
// num-outputs:
size_t argc = 4;
napi_value args[4];
Expand All @@ -118,7 +127,10 @@ static napi_value ExecuteOp(napi_env env, napi_callback_info info) {
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

if (argc < 4) {
NAPI_THROW_ERROR(env, "Invalid number of args passed to executeOp()");
NAPI_THROW_ERROR(env,
"Invalid number of args passed to executeOp(). Expecting "
"4 args but got %d.",
argc);
return nullptr;
}

Expand All @@ -140,6 +152,54 @@ static napi_value IsUsingGPUDevice(napi_env env, napi_callback_info info) {
return result;
}

static napi_value LoadSavedModel(napi_env env, napi_callback_info info) {
napi_status nstatus;

// Load saved model takes 2 params: export_dir, tags:
size_t argc = 2;
napi_value args[2];
napi_value js_this;
nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

if (argc < 2) {
NAPI_THROW_ERROR(env,
"Invalid number of args passed to LoadSavedModel(). "
"Expecting 2 args but got %d.",
argc);
return nullptr;
}

ENSURE_VALUE_IS_STRING_RETVAL(env, args[0], nullptr);
ENSURE_VALUE_IS_STRING_RETVAL(env, args[1], nullptr);

return gBackend->LoadSavedModel(env, args[0], args[1]);
}

static napi_value DeleteSavedModel(napi_env env, napi_callback_info info) {
napi_status nstatus;

// Delete SavedModel takes 1 param: savedModel ID;
size_t argc = 1;
napi_value args[1];
napi_value js_this;
nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);

if (argc < 1) {
NAPI_THROW_ERROR(env,
"Invalid number of args passed to deleteSavedModel(). "
"Expecting 1 arg but got %d.",
argc);
return js_this;
}

ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], js_this);

gBackend->DeleteSavedModel(env, args[0]);
return js_this;
}

static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) {
napi_status nstatus;

Expand All @@ -161,6 +221,10 @@ static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) {
napi_default, nullptr},
{"executeOp", nullptr, ExecuteOp, nullptr, nullptr, nullptr, napi_default,
nullptr},
{"loadSavedModel", nullptr, LoadSavedModel, nullptr, nullptr, nullptr,
napi_default, nullptr},
{"deleteSavedModel", nullptr, DeleteSavedModel, nullptr, nullptr, nullptr,
napi_default, nullptr},
{"TF_Version", nullptr, nullptr, nullptr, nullptr, tf_version,
napi_default, nullptr},
{"isUsingGpuDevice", nullptr, IsUsingGPUDevice, nullptr, nullptr, nullptr,
Expand Down

0 comments on commit 10ba224

Please sign in to comment.