Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tfjs-core/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import * as test_util from './test_util';
import * as util from './util';
import {version} from './version';

export {InferenceModel, MetaGraph, MetaGraphInfo, ModelPredictConfig, ModelTensorInfo, SavedModelTensorInfo, SignatureDef, SignatureDefInfo} from './model_types';
export {InferenceModel, MetaGraph, MetaGraphInfo, ModelPredictConfig, ModelTensorInfo, SavedModelTensorInfo, SignatureDef, SignatureDefEntry, SignatureDefInfo} from './model_types';
// Optimizers.
export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
export {AdagradOptimizer} from './optimizers/adagrad_optimizer';
Expand Down
15 changes: 11 additions & 4 deletions tfjs-core/src/model_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ export interface ModelTensorInfo {
shape?: number[];
// Data type of the tensor.
dtype: DataType;
// TensorFlow native Data type of the tensor.
tfDtype?: string;
}

/**
Expand Down Expand Up @@ -138,12 +140,17 @@ export interface MetaGraph {
signatureDefs: SignatureDef;
}

/**
 * Interface for a single SavedModel/GraphModel SignatureDef entry: the named
 * input and output tensor maps of one signature (e.g. 'serving_default').
 */
export interface SignatureDefEntry {
// Maps each signature input name to its tensor metadata.
inputs: {[key: string]: ModelTensorInfo};
// Maps each signature output name to its tensor metadata.
outputs: {[key: string]: ModelTensorInfo};
}

/**
 * Interface for SavedModel/GraphModel SignatureDef info: maps each signature
 * name to its SignatureDefEntry (input/output tensor maps).
 */
export interface SignatureDef {
  [key: string]: SignatureDefEntry;
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
35 changes: 35 additions & 0 deletions tfjs-node/python/unint8_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python script for creating Tensorflow SavedModel with UINT8 input."""

import os

import tensorflow as tf
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
from tensorflow.python.saved_model.save import save

# Build a tiny model whose concrete function takes a uint8 input and
# multiplies it by two int32 variables: f(x) = v1 * v2 * int32(x).
# (The previous bare string here described function inlining — copied from an
# unrelated test — and was a misleading no-op statement.)
input_data = constant_op.constant(1, shape=[1], dtype=tf.uint8)
root = tracking.AutoTrackable()
root.v1 = variables.Variable(3)
root.v2 = variables.Variable(2)
# uint8 cannot be multiplied with the int32 variables directly, so the input
# is cast to int32 inside the traced function.
root.f = def_function.function(
    lambda x: root.v1 * root.v2 * tf.cast(x, tf.int32))
to_save = root.f.get_concrete_function(input_data)

# Write the SavedModel consumed by saved_model_test.ts ('uint8_multiply').
save_dir = os.path.join('..', 'test_objects', 'saved_model', 'uint8_multiply')
save(root, save_dir, to_save)
22 changes: 19 additions & 3 deletions tfjs-node/src/nodejs_kernel_backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

import * as tf from '@tensorflow/tfjs';
import {backend_util, BackendTimingInfo, DataId, DataType, fill, KernelBackend, ones, Rank, rsqrt, Scalar, scalar, ScalarLike, ShapeMap, Tensor, Tensor1D, tensor1d, Tensor2D, tensor2d, Tensor3D, Tensor4D, Tensor5D, TensorInfo, tidy, util} from '@tensorflow/tfjs';
import {backend_util, BackendTimingInfo, DataId, DataType, fill, KernelBackend, ModelTensorInfo, ones, Rank, rsqrt, Scalar, scalar, ScalarLike, ShapeMap, Tensor, Tensor1D, tensor1d, Tensor2D, tensor2d, Tensor3D, Tensor4D, Tensor5D, TensorInfo, tidy, util} from '@tensorflow/tfjs';
import {isArray, isNullOrUndefined} from 'util';

import {Int64Scalar} from './int64_tensors';
Expand Down Expand Up @@ -1903,11 +1903,27 @@ export class NodeJSKernelBackend extends KernelBackend {
return this.binding.loadSavedModel(path, tags);
}

// Resolves the native tensor ids for `inputs`. Any input whose SavedModel
// dtype is DT_UINT8 gets its id replaced by a freshly created native TF_UINT8
// tensor built from the tensor's data, so TF receives the dtype the
// signature declares.
private getMappedInputTensorIds(
    inputs: Tensor[], inputTensorInfos: ModelTensorInfo[]) {
  const ids = this.getInputTensorIds(inputs);
  inputs.forEach((input, index) => {
    const info = inputTensorInfos[index];
    if (info != null && info.tfDtype === 'DT_UINT8') {
      const bytes = Uint8Array.from(input.dataSync());
      ids[index] =
          this.binding.createTensor(input.shape, this.binding.TF_UINT8, bytes);
    }
  });
  return ids;
}

/**
 * Execute a loaded SavedModel session with the given inputs.
 *
 * @param id Id of the loaded SavedModel session.
 * @param inputs Input tensors, ordered to match `inputTensorInfos`.
 * @param inputTensorInfos Metadata for each input; `name` supplies the input
 *     op name and `tfDtype` drives the DT_UINT8 remapping in
 *     getMappedInputTensorIds().
 * @param outputOpNames Output op names of the signature.
 * @returns Output tensors created from the session run results.
 */
runSavedModel(
    id: number, inputs: Tensor[], inputTensorInfos: ModelTensorInfo[],
    outputOpNames: string[]): Tensor[] {
  const outputMetadata = this.binding.runSavedModel(
      id, this.getMappedInputTensorIds(inputs, inputTensorInfos),
      inputTensorInfos.map(info => info.name).join(','),
      outputOpNames.join(','));
  return outputMetadata.map(m => this.createOutputTensor(m));
}
Expand Down
83 changes: 42 additions & 41 deletions tfjs-node/src/saved_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
* =============================================================================
*/

import {DataType, InferenceModel, MetaGraph, ModelPredictConfig, ModelTensorInfo, NamedTensorMap, SignatureDef, Tensor, util} from '@tensorflow/tfjs';
import {DataType, InferenceModel, MetaGraph, ModelPredictConfig, ModelTensorInfo, NamedTensorMap, SignatureDef, SignatureDefEntry, Tensor, util} from '@tensorflow/tfjs';
import * as fs from 'fs';
import {promisify} from 'util';

import {ensureTensorflowBackend, nodeBackend, NodeJSKernelBackend} from './nodejs_kernel_backend';

const readFile = promisify(fs.readFile);
Expand Down Expand Up @@ -88,8 +89,8 @@ export async function getMetaGraphsFromSavedModel(path: string):
// Get SavedModel proto message
const modelMessage = await readSavedModelProto(path);

// A SavedModel might have multiple MetaGraphs, identified by tags. Each
// MetaGraph also has it's own signatureDefs.
// A SavedModel might have multiple MetaGraphs, identified by tags.
// Each MetaGraph also has its own signatureDefs.
const metaGraphList = modelMessage.getMetaGraphsList();
for (let i = 0; i < metaGraphList.length; i++) {
const metaGraph = {} as MetaGraph;
Expand Down Expand Up @@ -124,9 +125,10 @@ export async function getMetaGraphsFromSavedModel(path: string):
}
const inputTensor = inputsMapMessage.get(inputsMapKey.value);
const inputTensorInfo = {} as ModelTensorInfo;
inputTensorInfo.dtype = mapTFDtypeToJSDtype(
getEnumKeyFromValue(messages.DataType, inputTensor.getDtype()));

const dtype =
getEnumKeyFromValue(messages.DataType, inputTensor.getDtype());
inputTensorInfo.dtype = mapTFDtypeToJSDtype(dtype);
inputTensorInfo.tfDtype = dtype;
inputTensorInfo.name = inputTensor.getName();
inputTensorInfo.shape = inputTensor.getTensorShape().getDimList();
inputs[inputsMapKey.value] = inputTensorInfo;
Expand All @@ -143,8 +145,10 @@ export async function getMetaGraphsFromSavedModel(path: string):
}
const outputTensor = outputsMapMessage.get(outputsMapKey.value);
const outputTensorInfo = {} as ModelTensorInfo;
outputTensorInfo.dtype = mapTFDtypeToJSDtype(
getEnumKeyFromValue(messages.DataType, outputTensor.getDtype()));
const dtype =
getEnumKeyFromValue(messages.DataType, outputTensor.getDtype());
outputTensorInfo.dtype = mapTFDtypeToJSDtype(dtype);
outputTensorInfo.tfDtype = dtype;
outputTensorInfo.name = outputTensor.getName();
outputTensorInfo.shape = outputTensor.getTensorShape().getDimList();
outputs[outputsMapKey.value] = outputTensorInfo;
Expand All @@ -160,39 +164,24 @@ export async function getMetaGraphsFromSavedModel(path: string):
}

/**
* Get input and output node names from SavedModel metagraphs info. The
* input.output node names will be used when executing a SavedModel signature.
* Get SignatureDefEntry from SavedModel metagraphs info. The SignatureDefEntry
* will be used when executing a SavedModel signature.
*
* @param savedModelInfo The MetaGraphInfo array loaded through
* getMetaGraphsFromSavedModel().
* @param tags The tags of the MetaGraph to get input/output node names from.
* @param signature The signature to get input/output node names from.
*/
export function getInputAndOutputNodeNameFromMetaGraphInfo(
savedModelInfo: MetaGraph[], tags: string[], signature: string) {
export function getSignatureDefEntryFromMetaGraphInfo(
savedModelInfo: MetaGraph[], tags: string[],
signature: string): SignatureDefEntry {
for (let i = 0; i < savedModelInfo.length; i++) {
const metaGraphInfo = savedModelInfo[i];
if (stringArraysHaveSameElements(tags, metaGraphInfo.tags)) {
if (metaGraphInfo.signatureDefs[signature] == null) {
throw new Error('The SavedModel does not have signature: ' + signature);
}
const inputNodeNames: {[key: string]: string} = {};
const outputNodeNames: {[key: string]: string} = {};
for (const signatureDef of Object.keys(metaGraphInfo.signatureDefs)) {
if (signatureDef === signature) {
for (const tensorName of Object.keys(
metaGraphInfo.signatureDefs[signature].inputs)) {
inputNodeNames[tensorName] =
metaGraphInfo.signatureDefs[signature].inputs[tensorName].name;
}
for (const tensorName of Object.keys(
metaGraphInfo.signatureDefs[signature].outputs)) {
outputNodeNames[tensorName] =
metaGraphInfo.signatureDefs[signature].outputs[tensorName].name;
}
}
}
return [inputNodeNames, outputNodeNames];
return metaGraphInfo.signatureDefs[signature];
}
}
throw new Error(`The SavedModel does not have tags: ${tags}`);
Expand All @@ -206,11 +195,10 @@ export function getInputAndOutputNodeNameFromMetaGraphInfo(
*/
export class TFSavedModel implements InferenceModel {
private disposed = false;

private outputNodeNames_: {[key: string]: string};
// Holds the native session id, the JS-side model id, the resolved
// SignatureDefEntry for the chosen signature, and the node backend used to
// run the session.
constructor(
    private sessionId: number, private jsid: number,
    private signature: SignatureDefEntry,
    private backend: NodeJSKernelBackend) {}

/**
Expand Down Expand Up @@ -254,6 +242,19 @@ export class TFSavedModel implements InferenceModel {
}
}

// Lazily builds (and memoizes in outputNodeNames_) the map from signature
// output key to the underlying tensor's node name.
get outputNodeNames() {
  if (this.outputNodeNames_ == null) {
    const mapped: {[key: string]: string} = {};
    for (const key of Object.keys(this.signature.outputs)) {
      mapped[key] = this.signature.outputs[key].name;
    }
    this.outputNodeNames_ = mapped;
  }
  return this.outputNodeNames_;
}

/**
* Execute the inference for the input tensors.
*
Expand Down Expand Up @@ -287,27 +288,27 @@ export class TFSavedModel implements InferenceModel {
if (inputs instanceof Tensor) {
inputTensors.push(inputs);
const result = this.backend.runSavedModel(
this.sessionId, inputTensors, Object.values(this.inputNodeNames),
this.sessionId, inputTensors, Object.values(this.signature.inputs),
Object.values(this.outputNodeNames));
return result.length > 1 ? result : result[0];
} else if (Array.isArray(inputs)) {
inputTensors = inputs;
return this.backend.runSavedModel(
this.sessionId, inputTensors, Object.values(this.inputNodeNames),
this.sessionId, inputTensors, Object.values(this.signature.inputs),
Object.values(this.outputNodeNames));
} else {
const inputTensorNames = Object.keys(this.inputNodeNames);
const inputTensorNames = Object.keys(this.signature.inputs);
const providedInputNames = Object.keys(inputs);
if (!stringArraysHaveSameElements(
inputTensorNames, providedInputNames)) {
throw new Error(`The model signatureDef input names are ${
inputTensorNames.join()}, however the provided input names are ${
providedInputNames.join()}.`);
}
const inputNodeNamesArray = [];
const inputNodeNamesArray: ModelTensorInfo[] = [];
for (let i = 0; i < inputTensorNames.length; i++) {
inputTensors.push(inputs[inputTensorNames[i]]);
inputNodeNamesArray.push(this.inputNodeNames[inputTensorNames[i]]);
inputNodeNamesArray.push(this.signature.inputs[inputTensorNames[i]]);
}
const outputTensorNames = Object.keys(this.outputNodeNames);
const outputNodeNamesArray = [];
Expand Down Expand Up @@ -387,9 +388,8 @@ export async function loadSavedModel(
const backend = nodeBackend();

const savedModelInfo = await getMetaGraphsFromSavedModel(path);
const [inputNodeNames, outputNodeNames] =
getInputAndOutputNodeNameFromMetaGraphInfo(
savedModelInfo, tags, signature);
const signatureDefEntry =
getSignatureDefEntryFromMetaGraphInfo(savedModelInfo, tags, signature);

let sessionId: number;

Expand All @@ -407,7 +407,7 @@ export async function loadSavedModel(
}
const id = nextTFSavedModelId++;
const savedModel =
new TFSavedModel(sessionId, id, inputNodeNames, outputNodeNames, backend);
new TFSavedModel(sessionId, id, signatureDefEntry, backend);
loadedSavedModelPathMap.set(id, {path, tags, sessionId});
return savedModel;
}
Expand All @@ -431,6 +431,7 @@ function mapTFDtypeToJSDtype(tfDtype: string): DataType {
case 'DT_FLOAT':
return 'float32';
case 'DT_INT32':
case 'DT_UINT8':
return 'int32';
case 'DT_BOOL':
return 'bool';
Expand Down
21 changes: 16 additions & 5 deletions tfjs-node/src/saved_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import {NamedTensorMap, test_util} from '@tensorflow/tfjs';
import * as tf from './index';
import {nodeBackend} from './nodejs_kernel_backend';
import {getEnumKeyFromValue, getInputAndOutputNodeNameFromMetaGraphInfo, readSavedModelProto} from './saved_model';
import {getEnumKeyFromValue, getSignatureDefEntryFromMetaGraphInfo, readSavedModelProto} from './saved_model';

// tslint:disable-next-line:no-require-imports
const messages = require('./proto/api_pb');
Expand Down Expand Up @@ -187,11 +187,11 @@ describe('SavedModel', () => {
it('get input and output node names from SavedModel metagraphs', async () => {
const modelInfo = await tf.node.getMetaGraphsFromSavedModel(
'./test_objects/saved_model/times_three_float');
const inputAndOutputNodeNames = getInputAndOutputNodeNameFromMetaGraphInfo(
const signature = getSignatureDefEntryFromMetaGraphInfo(
modelInfo, ['serve'], 'serving_default');
expect(inputAndOutputNodeNames.length).toBe(2);
expect(inputAndOutputNodeNames[0]['x']).toBe('serving_default_x:0');
expect(inputAndOutputNodeNames[1]['output_0'])
expect(Object.keys(signature).length).toBe(2);
expect(signature.inputs['x'].name).toBe('serving_default_x:0');
expect(signature.outputs['output_0'].name)
.toBe('StatefulPartitionedCall:0');
});

Expand Down Expand Up @@ -394,6 +394,17 @@ describe('SavedModel', () => {
done();
}
});
// The uint8_multiply model (see python/unint8_model.py in this change)
// computes v1 * v2 * int32(x) with v1=3, v2=2, so x=3 yields 18.
it('execute model with uint8 input', async () => {
const model = await tf.node.loadSavedModel(
'./test_objects/saved_model/uint8_multiply', ['serve'],
'serving_default');
const input = tf.scalar(3, 'int32');
const output = model.predict(input) as tf.Tensor;
expect(output.shape).toEqual([]);
expect(output.dtype).toBe('int32');
test_util.expectArraysClose(await output.data(), [18]);
model.dispose();
});

it('execute model int times two', async () => {
const model = await tf.node.loadSavedModel(
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.