Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tfjs-node/binding/tfjs_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1218,4 +1218,13 @@ napi_value TFJSBackend::RunSavedModel(napi_env env,
return output_tensor_infos;
}

napi_value TFJSBackend::GetNumOfSavedModels(napi_env env) {
napi_status nstatus;
napi_value num_saved_models;
nstatus =
napi_create_int32(env, tf_savedmodel_map_.size(), &num_saved_models);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
return num_saved_models;
}

} // namespace tfnodejs
3 changes: 3 additions & 0 deletions tfjs-node/binding/tfjs_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class TFJSBackend {
napi_value input_op_names,
napi_value output_op_names);

// Get number of loaded SavedModel in the backend:
napi_value GetNumOfSavedModels(napi_env env);

private:
TFJSBackend(napi_env env);
~TFJSBackend();
Expand Down
7 changes: 7 additions & 0 deletions tfjs-node/binding/tfjs_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ static napi_value RunSavedModel(napi_env env, napi_callback_info info) {
return gBackend->RunSavedModel(env, args[0], args[1], args[2], args[3]);
}

static napi_value GetNumOfSavedModels(napi_env env, napi_callback_info info) {
// Delete SavedModel takes 0 param;
return gBackend->GetNumOfSavedModels(env);
}

static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) {
napi_status nstatus;

Expand Down Expand Up @@ -255,6 +260,8 @@ static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) {
napi_default, nullptr},
{"isUsingGpuDevice", nullptr, IsUsingGPUDevice, nullptr, nullptr, nullptr,
napi_default, nullptr},
{"getNumOfSavedModels", nullptr, GetNumOfSavedModels, nullptr, nullptr,
nullptr, napi_default, nullptr},
};
nstatus = napi_define_properties(env, exports, ARRAY_SIZE(exports_properties),
exports_properties);
Expand Down
3 changes: 2 additions & 1 deletion tfjs-node/src/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import {tensorBoard} from './callbacks';
import {decodeBmp, decodeGif, decodeImage, decodeJpeg, decodePng, encodeJpeg, encodePng} from './image';
import {getMetaGraphsFromSavedModel, loadSavedModel} from './saved_model';
import {getMetaGraphsFromSavedModel, getNumOfSavedModels, loadSavedModel} from './saved_model';
import {summaryFileWriter} from './tensorboard';

export const node = {
Expand All @@ -35,5 +35,6 @@ export const node = {
summaryFileWriter,
tensorBoard,
getMetaGraphsFromSavedModel,
getNumOfSavedModels,
loadSavedModel
};
5 changes: 5 additions & 0 deletions tfjs-node/src/nodejs_kernel_backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {isArray, isNullOrUndefined} from 'util';

import {Int64Scalar} from './int64_tensors';
import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding';

type TensorData = {
shape: number[],
dtype: number,
Expand Down Expand Up @@ -1962,6 +1963,10 @@ export class NodeJSKernelBackend extends KernelBackend {
const elapsed = process.hrtime(start);
return {kernelMs: elapsed[0] * 1000 + elapsed[1] / 1000000};
}

getNumOfSavedModels() {
return this.binding.getNumOfSavedModels();
}
}

/** Returns an instance of the Node.js backend. */
Expand Down
6 changes: 6 additions & 0 deletions tfjs-node/src/saved_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,9 @@ function mapTFDtypeToJSDtype(tfDtype: string): DataType {
throw new Error('Unsupported tensor DataType: ' + tfDtype);
}
}

export function getNumOfSavedModels() {
ensureTensorflowBackend();
const backend = nodeBackend();
return backend.getNumOfSavedModels();
}
44 changes: 35 additions & 9 deletions tfjs-node/src/saved_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ describe('SavedModel', () => {
});

it('load TFSavedModel and delete', async () => {
expect(tf.node.getNumOfSavedModels()).toBe(0);
const loadSavedModelMetaGraphSpy =
spyOn(nodeBackend(), 'loadSavedModelMetaGraph').and.callThrough();
const deleteSavedModelSpy =
Expand All @@ -244,9 +245,11 @@ describe('SavedModel', () => {
'serving_default');
expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1);
expect(deleteSavedModelSpy).toHaveBeenCalledTimes(0);
expect(tf.node.getNumOfSavedModels()).toBe(1);
model.dispose();
expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1);
expect(deleteSavedModelSpy).toHaveBeenCalledTimes(1);
expect(tf.node.getNumOfSavedModels()).toBe(0);
});

it('delete TFSavedModel multiple times throw exception', async done => {
Expand All @@ -265,6 +268,7 @@ describe('SavedModel', () => {

it('load multiple signatures from the same metagraph only call binding once',
async () => {
expect(tf.node.getNumOfSavedModels()).toBe(0);
const backend = nodeBackend();
const loadSavedModelMetaGraphSpy =
spyOn(backend, 'loadSavedModelMetaGraph').and.callThrough();
Expand All @@ -273,13 +277,17 @@ describe('SavedModel', () => {
'./test_objects/saved_model/module_with_multiple_signatures',
['serve'], 'serving_default');
expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1);
expect(tf.node.getNumOfSavedModels()).toBe(1);
const model2 = await tf.node.loadSavedModel(
'./test_objects/saved_model/module_with_multiple_signatures',
['serve'], 'timestwo');
expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1);
expect(tf.node.getNumOfSavedModels()).toBe(1);
model1.dispose();
expect(tf.node.getNumOfSavedModels()).toBe(1);
model2.dispose();
expect(loadSavedModelMetaGraphSpy).toHaveBeenCalledTimes(1);
expect(tf.node.getNumOfSavedModels()).toBe(0);
});

it('load signature after delete call binding', async () => {
Expand Down Expand Up @@ -309,19 +317,20 @@ describe('SavedModel', () => {
});

it('throw error when input tensors do not match input ops', async done => {
const model = await tf.node.loadSavedModel(
'./test_objects/saved_model/times_three_float', ['serve'],
'serving_default');
const input1 = tf.tensor1d([1.0, 2, 3]);
const input2 = tf.tensor1d([1.0, 2, 3]);
try {
const model = await tf.node.loadSavedModel(
'./test_objects/saved_model/times_three_float', ['serve'],
'serving_default');
const input1 = tf.tensor1d([1.0, 2, 3]);
const input2 = tf.tensor1d([1.0, 2, 3]);
model.predict([input1, input2]);
done.fail();
} catch (error) {
expect(error.message)
.toBe(
'Length of input op names (1) does not match the ' +
'length of input tensors (2).');
model.dispose();
done();
}
});
Expand Down Expand Up @@ -369,18 +378,19 @@ describe('SavedModel', () => {
});

it('execute model with wrong tensor name', async done => {
const model = await tf.node.loadSavedModel(
'./test_objects/saved_model/times_three_float', ['serve'],
'serving_default');
const input = tf.tensor1d([1.0, 2, 3]);
try {
const model = await tf.node.loadSavedModel(
'./test_objects/saved_model/times_three_float', ['serve'],
'serving_default');
const input = tf.tensor1d([1.0, 2, 3]);
model.predict({'xyz': input});
done.fail();
} catch (error) {
expect(error.message)
.toBe(
'The model signatureDef input names are x, however ' +
'the provided input names are xyz.');
model.dispose();
done();
}
});
Expand Down Expand Up @@ -466,4 +476,20 @@ describe('SavedModel', () => {
test_util.expectArraysClose(await output2.data(), [1, 2, 3]);
model.dispose();
});

it('load multiple models', async () => {
expect(tf.node.getNumOfSavedModels()).toBe(0);
const model1 = await tf.node.loadSavedModel(
'./test_objects/saved_model/module_with_multiple_signatures', ['serve'],
'serving_default');
expect(tf.node.getNumOfSavedModels()).toBe(1);
const model2 = await tf.node.loadSavedModel(
'./test_objects/saved_model/model_multi_output', ['serve'],
'serving_default');
expect(tf.node.getNumOfSavedModels()).toBe(2);
model1.dispose();
expect(tf.node.getNumOfSavedModels()).toBe(1);
model2.dispose();
expect(tf.node.getNumOfSavedModels()).toBe(0);
});
});
2 changes: 2 additions & 0 deletions tfjs-node/src/tfjs_binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ export interface TFJSBinding {
savedModelId: number, inputTensorIds: number[], inputOpNames: string,
outputOpNames: string): TensorMetadata[];

getNumOfSavedModels(): number;

isUsingGpuDevice(): boolean;

// TF Types
Expand Down