diff --git a/tfjs-node/binding/tfjs_backend.cc b/tfjs-node/binding/tfjs_backend.cc index e536140138f..9086d9e0c8a 100644 --- a/tfjs-node/binding/tfjs_backend.cc +++ b/tfjs-node/binding/tfjs_backend.cc @@ -1218,4 +1218,13 @@ napi_value TFJSBackend::RunSavedModel(napi_env env, return output_tensor_infos; } +napi_value TFJSBackend::GetNumOfSavedModels(napi_env env) { + napi_status nstatus; + napi_value num_saved_models; + nstatus = + napi_create_int32(env, tf_savedmodel_map_.size(), &num_saved_models); + ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr); + return num_saved_models; +} + } // namespace tfnodejs diff --git a/tfjs-node/binding/tfjs_backend.h b/tfjs-node/binding/tfjs_backend.h index c4c8402d983..2f7baa0edc4 100644 --- a/tfjs-node/binding/tfjs_backend.h +++ b/tfjs-node/binding/tfjs_backend.h @@ -80,6 +80,9 @@ class TFJSBackend { napi_value input_op_names, napi_value output_op_names); + // Get number of loaded SavedModel in the backend: + napi_value GetNumOfSavedModels(napi_env env); + private: TFJSBackend(napi_env env); ~TFJSBackend(); diff --git a/tfjs-node/binding/tfjs_binding.cc b/tfjs-node/binding/tfjs_binding.cc index 6bd85f178ee..88dfacc9738 100644 --- a/tfjs-node/binding/tfjs_binding.cc +++ b/tfjs-node/binding/tfjs_binding.cc @@ -224,6 +224,11 @@ static napi_value RunSavedModel(napi_env env, napi_callback_info info) { return gBackend->RunSavedModel(env, args[0], args[1], args[2], args[3]); } +static napi_value GetNumOfSavedModels(napi_env env, napi_callback_info info) { + // Delete SavedModel takes 0 param; + return gBackend->GetNumOfSavedModels(env); +} + static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) { napi_status nstatus; @@ -255,6 +260,8 @@ static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) { napi_default, nullptr}, {"isUsingGpuDevice", nullptr, IsUsingGPUDevice, nullptr, nullptr, nullptr, napi_default, nullptr}, + {"getNumOfSavedModels", nullptr, GetNumOfSavedModels, nullptr, nullptr, + nullptr, napi_default, nullptr}, }; nstatus = napi_define_properties(env, exports, ARRAY_SIZE(exports_properties), exports_properties); diff --git a/tfjs-node/src/node.ts b/tfjs-node/src/node.ts index f8cc9de50c7..e0e0fb5a37d 100644 --- a/tfjs-node/src/node.ts +++ b/tfjs-node/src/node.ts @@ -21,7 +21,7 @@ import {tensorBoard} from './callbacks'; import {decodeBmp, decodeGif, decodeImage, decodeJpeg, decodePng, encodeJpeg, encodePng} from './image'; -import {getMetaGraphsFromSavedModel, loadSavedModel} from './saved_model'; +import {getMetaGraphsFromSavedModel, getNumOfSavedModels, loadSavedModel} from './saved_model'; import {summaryFileWriter} from './tensorboard'; export const node = { @@ -35,5 +35,6 @@ export const node = { summaryFileWriter, tensorBoard, getMetaGraphsFromSavedModel, + getNumOfSavedModels, loadSavedModel }; diff --git a/tfjs-node/src/nodejs_kernel_backend.ts b/tfjs-node/src/nodejs_kernel_backend.ts index cca6fd2e229..0545aee0fdc 100644 --- a/tfjs-node/src/nodejs_kernel_backend.ts +++ b/tfjs-node/src/nodejs_kernel_backend.ts @@ -25,6 +25,7 @@ import {isArray, isNullOrUndefined} from 'util'; import {Int64Scalar} from './int64_tensors'; import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding'; + type TensorData = { shape: number[], dtype: number, @@ -1962,6 +1963,10 @@ export class NodeJSKernelBackend extends KernelBackend { const elapsed = process.hrtime(start); return {kernelMs: elapsed[0] * 1000 + elapsed[1] / 1000000}; } + + getNumOfSavedModels() { + return this.binding.getNumOfSavedModels(); + } } /** Returns an instance of the Node.js backend. */ diff --git a/tfjs-node/src/saved_model.ts b/tfjs-node/src/saved_model.ts index 5b0bff7e157..afdbc17877b 100644 --- a/tfjs-node/src/saved_model.ts +++ b/tfjs-node/src/saved_model.ts @@ -437,3 +437,9 @@ function mapTFDtypeToJSDtype(tfDtype: string): DataType { throw new Error('Unsupported tensor DataType: ' + tfDtype); } } + +export function getNumOfSavedModels() { + ensureTensorflowBackend(); + const backend = nodeBackend(); + return backend.getNumOfSavedModels(); +} diff --git a/tfjs-node/src/saved_model_test.ts b/tfjs-node/src/saved_model_test.ts index 99fcaceddc5..d0a9b903fe7 100644 --- a/tfjs-node/src/saved_model_test.ts +++ b/tfjs-node/src/saved_model_test.ts @@ -233,6 +233,7 @@ describe('SavedModel', () => { }); it('load TFSavedModel and delete', async () => { + expect(tf.node.getNumOfSavedModels()).toBe(0); const loadSavedModelMetaGraphSpy = spyOn(nodeBackend(), 'loadSavedModelMetaGraph').and.callThrough(); const deleteSavedModelSpy = @@ -244,9 +245,11 @@ describe('SavedModel', () => { 'serving_default'); expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1); expect(deleteSavedModelSpy).toHaveBeenCalledTimes(0); + expect(tf.node.getNumOfSavedModels()).toBe(1); model.dispose(); expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1); expect(deleteSavedModelSpy).toHaveBeenCalledTimes(1); + expect(tf.node.getNumOfSavedModels()).toBe(0); }); it('delete TFSavedModel multiple times throw exception', async done => { @@ -265,6 +268,7 @@ describe('SavedModel', () => { it('load multiple signatures from the same metagraph only call binding once', async () => { + expect(tf.node.getNumOfSavedModels()).toBe(0); const backend = nodeBackend(); const loadSavedModelMetaGraphSpy = spyOn(backend, 'loadSavedModelMetaGraph').and.callThrough(); @@ -273,13 +277,17 @@ describe('SavedModel', () => { './test_objects/saved_model/module_with_multiple_signatures', ['serve'], 'serving_default'); expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1); + expect(tf.node.getNumOfSavedModels()).toBe(1); const model2 = await tf.node.loadSavedModel( './test_objects/saved_model/module_with_multiple_signatures', ['serve'], 'timestwo'); expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1); + expect(tf.node.getNumOfSavedModels()).toBe(1); model1.dispose(); + expect(tf.node.getNumOfSavedModels()).toBe(1); model2.dispose(); expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1); + expect(tf.node.getNumOfSavedModels()).toBe(0); }); it('load signature after delete call binding', async () => { @@ -309,12 +317,12 @@ describe('SavedModel', () => { }); it('throw error when input tensors do not match input ops', async done => { + const model = await tf.node.loadSavedModel( + './test_objects/saved_model/times_three_float', ['serve'], + 'serving_default'); + const input1 = tf.tensor1d([1.0, 2, 3]); + const input2 = tf.tensor1d([1.0, 2, 3]); try { - const model = await tf.node.loadSavedModel( - './test_objects/saved_model/times_three_float', ['serve'], - 'serving_default'); - const input1 = tf.tensor1d([1.0, 2, 3]); - const input2 = tf.tensor1d([1.0, 2, 3]); model.predict([input1, input2]); done.fail(); } catch (error) { @@ -322,6 +330,7 @@ describe('SavedModel', () => { .toBe( 'Length of input op names (1) does not match the ' + 'length of input tensors (2).'); + model.dispose(); done(); } }); @@ -369,11 +378,11 @@ describe('SavedModel', () => { }); it('execute model with wrong tensor name', async done => { + const model = await tf.node.loadSavedModel( + './test_objects/saved_model/times_three_float', ['serve'], + 'serving_default'); + const input = tf.tensor1d([1.0, 2, 3]); try { - const model = await tf.node.loadSavedModel( - './test_objects/saved_model/times_three_float', ['serve'], - 'serving_default'); - const input = tf.tensor1d([1.0, 2, 3]); model.predict({'xyz': input}); done.fail(); } catch (error) { @@ -381,6 +390,7 @@ describe('SavedModel', () => { .toBe( 'The model signatureDef input names are x, however ' + 'the provided input names are xyz.'); + model.dispose(); done(); } }); @@ -466,4 +476,20 @@ describe('SavedModel', () => { test_util.expectArraysClose(await output2.data(), [1, 2, 3]); model.dispose(); }); + + it('load multiple models', async () => { + expect(tf.node.getNumOfSavedModels()).toBe(0); + const model1 = await tf.node.loadSavedModel( + './test_objects/saved_model/module_with_multiple_signatures', ['serve'], + 'serving_default'); + expect(tf.node.getNumOfSavedModels()).toBe(1); + const model2 = await tf.node.loadSavedModel( + './test_objects/saved_model/model_multi_output', ['serve'], + 'serving_default'); + expect(tf.node.getNumOfSavedModels()).toBe(2); + model1.dispose(); + expect(tf.node.getNumOfSavedModels()).toBe(1); + model2.dispose(); + expect(tf.node.getNumOfSavedModels()).toBe(0); + }); }); diff --git a/tfjs-node/src/tfjs_binding.ts b/tfjs-node/src/tfjs_binding.ts index cb64126376e..a01c46c9310 100644 --- a/tfjs-node/src/tfjs_binding.ts +++ b/tfjs-node/src/tfjs_binding.ts @@ -60,6 +60,8 @@ export interface TFJSBinding { savedModelId: number, inputTensorIds: number[], inputOpNames: string, outputOpNames: string): TensorMetadata[]; + getNumOfSavedModels(): number; + isUsingGpuDevice(): boolean; // TF Types